author    Michalis Spyrou <michalis.spyrou@arm.com>  2018-04-23 15:17:31 +0100
committer Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:51:37 +0000
commit    e250389ed6d78153a55382fa5b3519c151bfd79f (patch)
tree      80c63793769ad18fd0406e7f8b40840aed7ac3ce /src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
parent    79ffadebd8dff7eaecbcfa3a28106736f240f1c5 (diff)
COMPMID-810 Add NHWC data format support for NEON convolution
Change-Id: I2a7b49a12da7f3bc3f04749243b1dc111160de6e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129348
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
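The mechanism the patch leans on throughout is `get_data_layout_dimension_index`, which maps a logical dimension (width, height, channel) to its index in the tensor shape for a given layout, so the same code paths serve both NCHW and NHWC. Below is a minimal standalone sketch of the idea, assuming the library's convention that dimension 0 is the fastest-varying one ([W, H, C, N] for NCHW, [C, W, H, N] for NHWC); the re-implementation is illustrative only, not the library's actual helper.

```cpp
#include <cassert>

enum class DataLayout { NCHW, NHWC };
enum class DataLayoutDimension { WIDTH, HEIGHT, CHANNEL };

// Illustrative re-implementation. Dimension 0 is the fastest-varying one,
// so a tensor shape reads [W, H, C, N] in NCHW and [C, W, H, N] in NHWC.
inline int get_data_layout_dimension_index(DataLayout layout, DataLayoutDimension dim)
{
    if(layout == DataLayout::NCHW)
    {
        switch(dim)
        {
            case DataLayoutDimension::WIDTH:   return 0;
            case DataLayoutDimension::HEIGHT:  return 1;
            case DataLayoutDimension::CHANNEL: return 2;
        }
    }
    else
    {
        switch(dim)
        {
            case DataLayoutDimension::CHANNEL: return 0;
            case DataLayoutDimension::WIDTH:   return 1;
            case DataLayoutDimension::HEIGHT:  return 2;
        }
    }
    assert(false && "unhandled dimension");
    return -1;
}
```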
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp')
-rw-r--r--  src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp | 242
1 file changed, 151 insertions(+), 91 deletions(-)
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index 5a35463365..a5f30557a0 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -109,6 +109,14 @@ Status NEConvolutionLayerReshapeWeights::validate(const ITensorInfo *weights, co
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
+ // Checks performed when the bias is appended to the weights
+ if(append_bias)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(3));
+ ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+ }
+
if(transpose1xW)
{
TensorInfo weights_reshaped = weights->clone()->set_tensor_shape(get_reshaped_weights_shape(weights, append_bias));
@@ -159,7 +167,7 @@ TensorShape get_reshaped_weights_shape_conv(const ITensorInfo *weights, bool app
Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
const ActivationLayerInfo &act_info, DataType &dt,
- bool &append_bias,
+ bool &append_bias, bool &skip_im2col,
bool &are_weights_reshaped, unsigned int &kernel_width, unsigned int &kernel_height,
bool &is_fully_connected_convolution, bool &is_interleaved, bool &is_quantized, bool &is_activationlayer_enabled,
unsigned int &mat_weights_cols, unsigned int &mat_weights_rows,
@@ -168,9 +176,17 @@ Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInf
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input, weights);
- ARM_COMPUTE_RETURN_ERROR_ON(!weights_info.are_reshaped() && weights->dimension(2) != input->dimension(2));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights);
+
+ DataLayout data_layout = input->data_layout();
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+
+ ARM_COMPUTE_RETURN_ERROR_ON(!weights_info.are_reshaped() && weights->dimension(idx_channel) != input->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
ARM_COMPUTE_RETURN_ERROR_ON(weights_info.are_reshaped() && is_data_type_quantized_asymmetric(input->data_type()));
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32, "NHWC is only supported for FP32 data type.");
dt = input->data_type();
is_quantized = is_data_type_quantized_asymmetric(dt);
@@ -190,14 +206,16 @@ Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInf
ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
}
append_bias = (biases != nullptr) && (!is_quantized);
are_weights_reshaped = weights_info.are_reshaped();
- kernel_width = (are_weights_reshaped) ? weights_info.kernel_size().first : weights->dimension(0);
- kernel_height = (are_weights_reshaped) ? weights_info.kernel_size().second : weights->dimension(1);
+ kernel_width = (are_weights_reshaped) ? weights_info.kernel_size().first : weights->dimension(idx_width);
+ kernel_height = (are_weights_reshaped) ? weights_info.kernel_size().second : weights->dimension(idx_height);
+ // For a 1x1 convolution with NHWC data layout we can skip im2col entirely
+ skip_im2col = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1);
mat_weights_cols = weights->dimension(3);
- mat_weights_rows = weights->dimension(0) * weights->dimension(1) * weights->dimension(2) + (append_bias ? 1 : 0);
+ mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + ((append_bias && !skip_im2col) ? 1 : 0);
- std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_width, kernel_height,
+ std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(idx_width), input->dimension(idx_height), kernel_width, kernel_height,
conv_info, dilation);
// Check if it's a "fully connected" convolution
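Why a 1x1 kernel with NHWC lets im2col be skipped (the new `skip_im2col` flag above): with channels innermost, the receptive field of each output pixel is already a contiguous run of C input values, so the input viewed as an (H*W) x C matrix can feed the GEMM directly. A hedged sketch of that equivalence, with illustrative names and no padding or stride handling:

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: a 1x1 convolution over an NHWC buffer is a plain GEMM of
// (H*W) x C by C x K, with no im2col gather step in between.
void conv1x1_nhwc(const std::vector<float> &input,   // H*W*C values, channels innermost
                  const std::vector<float> &weights, // C*K values, one column per kernel
                  std::vector<float>       &output,  // H*W*K values
                  std::size_t hw, std::size_t c, std::size_t k)
{
    for(std::size_t p = 0; p < hw; ++p) // each spatial position is one GEMM row
    {
        for(std::size_t o = 0; o < k; ++o)
        {
            float acc = 0.f;
            for(std::size_t i = 0; i < c; ++i)
            {
                acc += input[p * c + i] * weights[i * k + o];
            }
            output[p * k + o] = acc;
        }
    }
}
```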
@@ -211,9 +229,9 @@ Status validate_and_initialize_values(const ITensorInfo *input, const ITensorInf
NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager)
: _asm_glue(), _memory_group(memory_manager), _input_im2col_kernel(), _input_interleave_kernel(), _reshape_weights(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
- _output_col2im_kernel(), _activationlayer_function(), _original_weights(nullptr), _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _gemm_output(), _tmp_output(),
- _workspace(), _B_pretransposed(), _append_bias(false), _is_fully_connected_convolution(false), _are_weights_reshaped(false), _is_quantized(false), _is_interleaved(false),
- _is_activationlayer_enabled(false)
+ _output_col2im_kernel(), _activationlayer_function(), _add_bias_kernel(), _original_weights(nullptr), _input_im2col_reshaped(), _input_interleaved_reshaped(), _weights_reshaped(), _gemm_output(),
+ _tmp_output(), _workspace(), _B_pretransposed(), _data_layout(DataLayout::NCHW), _append_bias(false), _is_fully_connected_convolution(false), _are_weights_reshaped(false), _is_quantized(false),
+ _is_interleaved(false), _is_activationlayer_enabled(false), _skip_im2col(false)
{
}
@@ -255,7 +273,13 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
unsigned int conv_w = 0;
unsigned int conv_h = 0;
- Status status = validate_and_initialize_values(input->info(), weights->info(), (biases == nullptr) ? nullptr : biases->info(), conv_info, weights_info, act_info, dt, _append_bias,
+ _data_layout = input->info()->data_layout();
+ const bool is_nhwc = _data_layout == DataLayout::NHWC;
+ const int idx_width = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
+ const int idx_channel = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::CHANNEL);
+
+ Status status = validate_and_initialize_values(input->info(), weights->info(), (biases == nullptr) ? nullptr : biases->info(), conv_info, weights_info, act_info, dt, _append_bias, _skip_im2col,
_are_weights_reshaped,
kernel_width, kernel_height,
_is_fully_connected_convolution, _is_interleaved, _is_quantized, _is_activationlayer_enabled,
@@ -272,20 +296,12 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
// Reshape weights if needed
if(run_optimised)
{
- if(_are_weights_reshaped)
- {
- mat_weights_cols = weights_info.num_kernels();
- mat_weights_rows = weights->info()->dimension(1);
- }
- else
- {
- TensorShape reshaped_weights_shape{ mat_weights_cols, mat_weights_rows };
+ TensorShape reshaped_weights_shape{ mat_weights_cols, mat_weights_rows };
- // Create tensor to store the reshaped weights
- _weights_reshaped.allocator()->init(TensorInfo(reshaped_weights_shape, 1, dt, fixed_point_position));
- _reshape_weights.configure(weights, biases, &_weights_reshaped, false /* 1xW transpose */);
- weights = &_weights_reshaped;
- }
+ // Create tensor to store the reshaped weights
+ _weights_reshaped.allocator()->init(TensorInfo(reshaped_weights_shape, 1, dt, fixed_point_position));
+ _reshape_weights.configure(weights, biases, &_weights_reshaped, false /* 1xW transpose */);
+ weights = &_weights_reshaped;
}
else
{
@@ -294,12 +310,12 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
if(_is_fully_connected_convolution || _is_quantized)
{
mat_weights_cols = weights_info.num_kernels();
- mat_weights_rows = weights->info()->dimension(1);
+ mat_weights_rows = weights->info()->dimension(idx_height);
}
else
{
mat_weights_cols = weights_info.num_kernels();
- mat_weights_rows = weights_info.kernel_size().first * weights_info.kernel_size().second * input->info()->dimension(2) + (_append_bias ? 1 : 0);
+ mat_weights_rows = weights_info.kernel_size().first * weights_info.kernel_size().second * input->info()->dimension(idx_channel) + (_append_bias ? 1 : 0);
}
}
else
@@ -325,48 +341,56 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
}
}
- // Create tensor to store im2col reshaped inputs
- const unsigned int mat_input_cols = mat_weights_rows;
- const unsigned int mat_input_rows = conv_w * conv_h;
-
- TensorShape shape_im2col(input->info()->tensor_shape());
- shape_im2col.set(0, mat_input_cols);
- shape_im2col.set(1, mat_input_rows);
- shape_im2col.set(2, 1);
- _input_im2col_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
- _memory_group.manage(&_input_im2col_reshaped);
+ // When im2col runs it appends the bias itself; when it is skipped, a separate add-bias kernel is configured in the else branch below
+ if(!_skip_im2col)
+ {
+ const unsigned int mat_input_cols = mat_weights_rows;
+ const unsigned int mat_input_rows = conv_w * conv_h;
+
+ // Create tensor to store im2col reshaped inputs
+ TensorShape shape_im2col(input->info()->tensor_shape());
+ shape_im2col.set(0, mat_input_cols);
+ shape_im2col.set(1, mat_input_rows);
+ shape_im2col.set(2, 1);
+ _input_im2col_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
+ _memory_group.manage(&_input_im2col_reshaped);
+
+ // Create tensor (interleave) to prepare input tensor for GEMM
+ if(!_is_fully_connected_convolution && !run_optimised && _is_interleaved)
+ {
+ TensorShape shape_interleaved(shape_im2col);
+ shape_interleaved.set(idx_width, shape_interleaved.x() * 4);
+ shape_interleaved.set(idx_height, std::ceil(shape_interleaved[idx_height] / 4.f));
+ _input_interleaved_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_interleaved));
+ _memory_group.manage(&_input_interleaved_reshaped);
+ }
- // Create tensor (interleave) to prepare input tensor for GEMM
- if(!_is_fully_connected_convolution && !run_optimised && _is_interleaved)
+ // Create GEMM output tensor
+ TensorShape shape_gemm(_input_im2col_reshaped.info()->tensor_shape());
+ shape_gemm.set(0, mat_weights_cols);
+ shape_gemm.set(1, mat_input_rows);
+ // For quantized asymmetric input, GEMM output must be S32 so the raw integer accumulators are available before the quantized post-processing
+ const DataType gemm_data_type = _is_quantized ? DataType::S32 : dt;
+ TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->info()->fixed_point_position());
+ info_gemm.set_quantization_info(output->info()->quantization_info());
+ _gemm_output.allocator()->init(info_gemm);
+
+ // FIXME: enabling memory manager for _gemm_output gives incorrect results (maybe bound to the assembly kernel in GEMMLowp?)
+ // _memory_group.manage(&_gemm_output);
+
+ // Configure im2col
+ _input_im2col_kernel.configure(input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, _append_bias, false, false, dilation);
+ }
+ else if(_append_bias)
{
- TensorShape shape_interleaved(shape_im2col);
- shape_interleaved.set(0, shape_interleaved.x() * 4);
- shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f));
- _input_interleaved_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_interleaved));
- _memory_group.manage(&_input_interleaved_reshaped);
+ // Configure add bias kernel
+ _add_bias_kernel.configure(output, biases, output, ConvertPolicy::SATURATE);
}
- // Create GEMM output tensor
- TensorShape shape_gemm(_input_im2col_reshaped.info()->tensor_shape());
- shape_gemm.set(0, mat_weights_cols);
- shape_gemm.set(1, mat_input_rows);
- const DataType gemm_data_type = _is_quantized ? DataType::S32 : dt;
- // GEMM output should be S32 for acquiring raw integer accumulator without quantized postprocessing for quantized asymmetric input.
- TensorInfo info_gemm(shape_gemm, 1, gemm_data_type, input->info()->fixed_point_position());
- info_gemm.set_quantization_info(output->info()->quantization_info());
- _gemm_output.allocator()->init(info_gemm);
-
- // FIXME: enabling memory manager for _gemm_output gives incorrect results (maybe bound to the assembly kernel in GEMMLowp?)
- // _memory_group.manage(&_gemm_output);
-
- // Configure kernels
- // Configure im2col
- _input_im2col_kernel.configure(input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, _append_bias, false, false, dilation);
-
// Configure matrix multiply
if(run_optimised)
{
- if(!setup_assembly_kernel(&_input_im2col_reshaped, weights, &_gemm_output, 1.f, 0.f, true, _workspace, _B_pretransposed, _memory_group, _asm_glue))
+ if(!setup_assembly_kernel(_skip_im2col ? input : &_input_im2col_reshaped, weights, is_nhwc ? output : &_gemm_output, 1.f, 0.f, true, _workspace, _B_pretransposed, _memory_group, _asm_glue))
{
ARM_COMPUTE_ERROR("setup_assembly_kernel failed.");
}
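To make the shape bookkeeping in this hunk concrete: `mat_input_cols` equals `mat_weights_rows` (kernel width times kernel height times input channels, plus one if the bias is appended) and `mat_input_rows` is the number of output pixels. A worked example with hypothetical sizes (3x3 kernel, 64 input channels, 16 kernels, 28x28 output):

```cpp
#include <iostream>

// Worked example of the shape bookkeeping, with hypothetical sizes:
// a 3x3 kernel over 64 input channels, 16 kernels, 28x28 output, bias appended.
int main()
{
    const unsigned int kernel_w = 3, kernel_h = 3, channels = 64;
    const unsigned int conv_w = 28, conv_h = 28, num_kernels = 16;
    const bool         append_bias = true;

    const unsigned int mat_weights_rows = kernel_w * kernel_h * channels + (append_bias ? 1 : 0); // 577
    const unsigned int mat_input_cols   = mat_weights_rows;                                       // 577
    const unsigned int mat_input_rows   = conv_w * conv_h;                                        // 784

    std::cout << "im2col output: " << mat_input_cols << " x " << mat_input_rows << "\n" // 577 x 784
              << "GEMM output:   " << num_kernels << " x " << mat_input_rows << "\n";   // 16 x 784
    return 0;
}
```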
@@ -379,8 +403,8 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
_input_interleave_kernel.configure(&_input_im2col_reshaped, &_input_interleaved_reshaped);
// Configure GEMM
- configure_mm(&_input_interleaved_reshaped, weights, &_gemm_output, _is_interleaved, GEMMReshapeInfo(_input_im2col_reshaped.info()->dimension(1), 0 /* no transpose */,
- _input_im2col_reshaped.info()->dimension(0)));
+ configure_mm(&_input_interleaved_reshaped, weights, &_gemm_output, _is_interleaved, GEMMReshapeInfo(_input_im2col_reshaped.info()->dimension(idx_height), 0 /* no transpose */,
+ _input_im2col_reshaped.info()->dimension(idx_width)));
_input_interleaved_reshaped.allocator()->allocate();
}
else
@@ -389,29 +413,36 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
}
}
- _input_im2col_reshaped.allocator()->allocate();
-
- // Configure output stage for quantized case
- if(_is_quantized)
+ if(!_skip_im2col)
{
- const QuantizationInfo output_quant_info = (output->info()->total_size() == 0) ? input->info()->quantization_info() : output->info()->quantization_info();
+ _input_im2col_reshaped.allocator()->allocate();
- float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
- int output_multiplier, output_shift;
- quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
- _memory_group.manage(&_tmp_output);
- _gemmlowp_output_stage.configure(&_gemm_output, biases, &_tmp_output, output_multiplier, output_shift, output_quant_info.offset);
- }
+ // Configure output stage for quantized case
+ if(_is_quantized)
+ {
+ const QuantizationInfo output_quant_info = (output->info()->total_size() == 0) ? input->info()->quantization_info() : output->info()->quantization_info();
- // Configure Col2Im
- _output_col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, Size2D(conv_w, conv_h));
- if(_is_quantized)
- {
- _tmp_output.allocator()->allocate();
+ float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output_quant_info.scale;
+ int output_multiplier, output_shift;
+ quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+ _memory_group.manage(&_tmp_output);
+ _gemmlowp_output_stage.configure(&_gemm_output, biases, &_tmp_output, output_multiplier, output_shift, output_quant_info.offset);
+ }
+
+ // Configure Col2Im
+ if(!is_nhwc)
+ {
+ _output_col2im_kernel.configure(_is_quantized ? &_tmp_output : &_gemm_output, output, Size2D(conv_w, conv_h));
+ }
+
+ if(_is_quantized)
+ {
+ _tmp_output.allocator()->allocate();
+ }
+ _gemm_output.allocator()->allocate();
}
- _gemm_output.allocator()->allocate();
- ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(0) != conv_w) || (output->info()->dimension(1) != conv_h), "Output shape does not match the expected one");
+ ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(idx_width) != conv_w) || (output->info()->dimension(idx_height) != conv_h), "Output shape does not match the expected one");
// Allocate intermediate tensor
if(!_are_weights_reshaped)
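The requantization step above folds three scales into one value, multiplier = (input_scale * weights_scale) / output_scale, which `calculate_quantized_multiplier_less_than_one` splits into a normalized int32 multiplier and a right shift so the output stage can stay in integer arithmetic. A simplified sketch of that decomposition, assuming the multiplier lies in [0, 1); the library's version carries additional error handling:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Simplified sketch: split a real multiplier in [0, 1) into a Q0.31 fixed-point
// multiplier and a right shift, so x * multiplier can be evaluated in integers.
void quantize_multiplier_less_than_one(float multiplier, std::int32_t *quantized_multiplier, int *right_shift)
{
    assert(multiplier >= 0.f && multiplier < 1.f);
    if(multiplier == 0.f)
    {
        *quantized_multiplier = 0;
        *right_shift          = 0;
        return;
    }
    const double q = std::frexp(multiplier, right_shift); // multiplier = q * 2^exp, q in [0.5, 1)
    *right_shift   = -*right_shift;                       // exp <= 0 here, so the shift is to the right
    auto q_fixed   = static_cast<std::int64_t>(std::round(q * (1ll << 31)));
    if(q_fixed == (1ll << 31)) // rounding carried into the next power of two
    {
        q_fixed /= 2;
        --*right_shift;
    }
    *quantized_multiplier = static_cast<std::int32_t>(q_fixed);
}
```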
@@ -433,6 +464,7 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
DataType dt{};
bool append_bias{};
+ bool skip_im2col{};
bool are_weights_reshaped{};
bool is_fully_connected_convolution{};
bool is_interleaved{};
@@ -445,7 +477,12 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
unsigned int conv_w = 0;
unsigned int conv_h = 0;
- Status status = validate_and_initialize_values(input, weights, biases, conv_info, weights_info, act_info, dt, append_bias, are_weights_reshaped, kernel_width, kernel_height,
+ const DataLayout data_layout = input->data_layout();
+ const bool is_nhwc = data_layout == DataLayout::NHWC;
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+
+ Status status = validate_and_initialize_values(input, weights, biases, conv_info, weights_info, act_info, dt, append_bias, skip_im2col, are_weights_reshaped, kernel_width, kernel_height,
is_fully_connected_convolution, is_interleaved, is_quantized, is_activationlayer_enabled, mat_weights_cols, mat_weights_rows,
conv_w, conv_h, dilation);
@@ -461,7 +498,6 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
optimised_kernel = true;
}
- // Validate im2col
const unsigned int mat_input_cols = mat_weights_rows;
const unsigned int mat_input_rows = conv_w * conv_h;
TensorShape shape_im2col = input->tensor_shape();
@@ -469,7 +505,17 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
shape_im2col.set(1, mat_input_rows);
shape_im2col.set(2, 1);
TensorInfo im2_col_info = input->clone()->set_tensor_shape(shape_im2col);
- ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &im2_col_info, kernel_weights, conv_info, append_bias, false, false, dilation));
+
+ if(!skip_im2col)
+ {
+ // Validate im2col
+ ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &im2_col_info, kernel_weights, conv_info, append_bias, false, false, dilation));
+ }
+ else if(append_bias)
+ {
+ // Validate add bias kernel
+ ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(output, biases, output, ConvertPolicy::SATURATE));
+ }
// Create GEMM output tensor
TensorShape shape_gemm(im2_col_info.tensor_shape());
@@ -511,8 +557,8 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
if(is_interleaved)
{
TensorShape shape_interleaved = shape_im2col;
- shape_interleaved.set(0, shape_interleaved.x() * 4);
- shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f));
+ shape_interleaved.set(idx_width, shape_interleaved.x() * 4);
+ shape_interleaved.set(idx_height, std::ceil(shape_interleaved[idx_height] / 4.f));
TensorInfo input_interleaved_info = input->clone()->set_tensor_shape(shape_interleaved);
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(&im2_col_info, &input_interleaved_info));
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&input_interleaved_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo(shape_im2col[1], // m
@@ -524,10 +570,12 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&im2_col_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo()));
}
}
+ if(!is_nhwc)
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(&gemm_output_info, output, Size2D(conv_w, conv_h)));
+ }
- ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(&gemm_output_info, output, Size2D(conv_w, conv_h)));
-
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((output->dimension(0) != conv_w) || (output->dimension(1) != conv_h), "Output shape does not match the expected one");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((output->dimension(idx_width) != conv_w) || (output->dimension(idx_height) != conv_h), "Output shape does not match the expected one");
if(act_info.enabled())
{
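The reason col2im can be dropped for NHWC: the GEMM output has `mat_weights_cols` (the kernel count) as its innermost dimension, which is exactly the channel-innermost order an NHWC output tensor expects, so the assembly kernel can write to `output` directly (as in the `is_nhwc ? output : &_gemm_output` selection above). A sketch, with illustrative names, of what the NCHW path still needs:

```cpp
// Illustrative col2im for NCHW, i.e. the step the NHWC path avoids: the GEMM
// result holds element (pixel p, kernel k) at p * num_kernels + k, and col2im
// scatters it into per-channel planes at k * conv_w * conv_h + p. In NHWC the
// destination index would be p * num_kernels + k, i.e. the GEMM layout itself.
void col2im_nchw(const float *gemm_out, float *output,
                 unsigned int num_kernels, unsigned int conv_w, unsigned int conv_h)
{
    for(unsigned int k = 0; k < num_kernels; ++k)
    {
        for(unsigned int p = 0; p < conv_w * conv_h; ++p)
        {
            output[k * conv_w * conv_h + p] = gemm_out[p * num_kernels + k];
        }
    }
}
```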
@@ -553,8 +601,12 @@ void NEGEMMConvolutionLayer::run()
_memory_group.acquire();
- // Run input reshaping
- NEScheduler::get().schedule(&_input_im2col_kernel, Window::DimY);
+ if(!_skip_im2col)
+ {
+ // Run input reshaping
+ const unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
+ NEScheduler::get().schedule(&_input_im2col_kernel, y_dim);
+ }
// Runs matrix multiply on reshaped matrices
if(_asm_glue._optimised_kernel != nullptr)
@@ -585,6 +637,11 @@ void NEGEMMConvolutionLayer::run()
}
}
+ if(_skip_im2col && _append_bias)
+ {
+ NEScheduler::get().schedule(&_add_bias_kernel, Window::DimY);
+ }
+
// Run output stage for quantized case
if(_is_quantized)
{
@@ -592,7 +649,10 @@ void NEGEMMConvolutionLayer::run()
}
// Reshape output matrix
- NEScheduler::get().schedule(&_output_col2im_kernel, Window::DimY);
+ if(_data_layout == DataLayout::NCHW)
+ {
+ NEScheduler::get().schedule(&_output_col2im_kernel, Window::DimY);
+ }
if(_is_activationlayer_enabled)
{
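Putting the run() changes together, the per-inference flow after this patch reduces to a handful of conditional steps. A hedged trace of that control flow; the flags are stand-ins for the members of the same name, and the real method dispatches NEON kernels rather than printing:

```cpp
#include <cstdio>

// Hedged summary of the order of operations run() ends up with after this patch.
void run_gemm_convolution(bool skip_im2col, bool append_bias, bool is_quantized, bool is_nchw)
{
    if(!skip_im2col)
    {
        std::puts("im2col: gather input patches (scheduled over the height dimension)");
    }
    std::puts("GEMM: multiply the (reshaped) input by the reshaped weights");
    if(skip_im2col && append_bias)
    {
        std::puts("add-bias kernel: saturating in-place add of the biases");
    }
    if(is_quantized)
    {
        std::puts("output stage: requantize the S32 accumulators");
    }
    if(is_nchw)
    {
        std::puts("col2im: scatter the GEMM output back into NCHW planes");
    }
    std::puts("activation layer, if enabled");
}
```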