diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2018-04-12 11:41:26 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:49:37 +0000 |
commit | a72300a5e4d44cdadfe37f69e21f9bf628d19bb3 (patch) | |
tree | 08170075522e84438693d2c79e10a87d9ef68b09 /src | |
parent | 95a3f55bbfb954aee58b2a13276035aa7b9e36a2 (diff) | |
download | ComputeLibrary-a72300a5e4d44cdadfe37f69e21f9bf628d19bb3.tar.gz |
COMPMID-1051 - Fix validate method in NEGEMMConvolutionLayer
Change-Id: I10e8e1267a09246cac77e677f1c087bb1d80a61b
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127517
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp | 100 |
1 files changed, 42 insertions, 58 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp index c339947633..7f25c2e717 100644 --- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp @@ -466,55 +466,6 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI optimised_kernel = true; } - // Reshape weights if needed - if(optimised_kernel) - { - if(are_weights_reshaped) - { - mat_weights_cols = weights_info.num_kernels(); - mat_weights_rows = weights->dimension(1); - } - else - { - TensorShape reshaped_weights_shape{ mat_weights_cols, mat_weights_rows }; - - // Create tensor to store the reshaped weights - reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution)); - ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */)); - weights = reshaped_weights.get(); - } - } - else - { - if(are_weights_reshaped) - { - const unsigned int transpose_width = 16 / input->element_size(); - mat_weights_cols = weights_info.num_kernels(); - mat_weights_rows = weights->dimension(0) / transpose_width + (append_bias ? 1 : 0); - } - else - { - TensorShape reshaped_weights_shape; - - if(is_fully_connected_convolution || is_quantized) - { - reshaped_weights_shape = TensorShape{ mat_weights_cols, mat_weights_rows }; - } - else - { - // Create tensor to store transposed weights - const float transpose_width = 16.0f / input->element_size(); - reshaped_weights_shape = TensorShape{ mat_weights_rows *static_cast<unsigned int>(transpose_width), - static_cast<unsigned int>(std::ceil(mat_weights_cols / transpose_width)) }; - } - - // Create tensor to store the reshaped weights - reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution)); - ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */)); - weights = reshaped_weights.get(); - } - } - // Validate im2col const unsigned int mat_input_cols = mat_weights_rows; const unsigned int mat_input_rows = conv_w * conv_h; @@ -531,19 +482,52 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI shape_gemm.set(1, mat_input_rows); TensorInfo gemm_output_info = input->clone()->set_tensor_shape(shape_gemm); - // Validate GEMM interleave and multiply - if(is_interleaved) + // Reshape weights if needed + if(optimised_kernel) { - TensorShape shape_interleaved = shape_im2col; - shape_interleaved.set(0, shape_interleaved.x() * 4); - shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f)); - TensorInfo input_interleaved_info = input->clone()->set_tensor_shape(shape_interleaved); - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(&im2_col_info, &input_interleaved_info)); - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&input_interleaved_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo())); + ARM_COMPUTE_RETURN_ERROR_ON(are_weights_reshaped); + + // Create tensor to store the reshaped weights + reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution)); + ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */)); } else { - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&im2_col_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo())); + TensorShape reshaped_weights_shape; + + if(is_fully_connected_convolution || is_quantized) + { + reshaped_weights_shape = TensorShape{ mat_weights_cols, mat_weights_rows }; + } + else + { + // Create tensor to store transposed weights + const float transpose_width = 16.0f / input->element_size(); + reshaped_weights_shape = TensorShape{ mat_weights_rows *static_cast<unsigned int>(transpose_width), + static_cast<unsigned int>(std::ceil(mat_weights_cols / transpose_width)) }; + } + + // Create tensor to store the reshaped weights + reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution)); + ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */)); + weights = reshaped_weights.get(); + + // Validate GEMM interleave and multiply + if(is_interleaved) + { + TensorShape shape_interleaved = shape_im2col; + shape_interleaved.set(0, shape_interleaved.x() * 4); + shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f)); + TensorInfo input_interleaved_info = input->clone()->set_tensor_shape(shape_interleaved); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(&im2_col_info, &input_interleaved_info)); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&input_interleaved_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo(shape_im2col[1], // m + weights->tensor_shape()[0], // n + shape_im2col[0]) /* k */)); + } + else + { + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&im2_col_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo())); + } } ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(&gemm_output_info, output, Size2D(conv_w, conv_h))); |