From a72300a5e4d44cdadfe37f69e21f9bf628d19bb3 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Thu, 12 Apr 2018 11:41:26 +0100 Subject: COMPMID-1051 - Fix validate method in NEGEMMConvolutionLayer Change-Id: I10e8e1267a09246cac77e677f1c087bb1d80a61b Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127517 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- .../NEON/functions/NEGEMMConvolutionLayer.cpp | 100 +++++++++------------ 1 file changed, 42 insertions(+), 58 deletions(-) (limited to 'src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp') diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp index c339947633..7f25c2e717 100644 --- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp +++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp @@ -466,55 +466,6 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI optimised_kernel = true; } - // Reshape weights if needed - if(optimised_kernel) - { - if(are_weights_reshaped) - { - mat_weights_cols = weights_info.num_kernels(); - mat_weights_rows = weights->dimension(1); - } - else - { - TensorShape reshaped_weights_shape{ mat_weights_cols, mat_weights_rows }; - - // Create tensor to store the reshaped weights - reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution)); - ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */)); - weights = reshaped_weights.get(); - } - } - else - { - if(are_weights_reshaped) - { - const unsigned int transpose_width = 16 / input->element_size(); - mat_weights_cols = weights_info.num_kernels(); - mat_weights_rows = weights->dimension(0) / transpose_width + (append_bias ? 1 : 0); - } - else - { - TensorShape reshaped_weights_shape; - - if(is_fully_connected_convolution || is_quantized) - { - reshaped_weights_shape = TensorShape{ mat_weights_cols, mat_weights_rows }; - } - else - { - // Create tensor to store transposed weights - const float transpose_width = 16.0f / input->element_size(); - reshaped_weights_shape = TensorShape{ mat_weights_rows *static_cast(transpose_width), - static_cast(std::ceil(mat_weights_cols / transpose_width)) }; - } - - // Create tensor to store the reshaped weights - reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution)); - ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */)); - weights = reshaped_weights.get(); - } - } - // Validate im2col const unsigned int mat_input_cols = mat_weights_rows; const unsigned int mat_input_rows = conv_w * conv_h; @@ -531,19 +482,52 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI shape_gemm.set(1, mat_input_rows); TensorInfo gemm_output_info = input->clone()->set_tensor_shape(shape_gemm); - // Validate GEMM interleave and multiply - if(is_interleaved) + // Reshape weights if needed + if(optimised_kernel) { - TensorShape shape_interleaved = shape_im2col; - shape_interleaved.set(0, shape_interleaved.x() * 4); - shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f)); - TensorInfo input_interleaved_info = input->clone()->set_tensor_shape(shape_interleaved); - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(&im2_col_info, &input_interleaved_info)); - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&input_interleaved_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo())); + ARM_COMPUTE_RETURN_ERROR_ON(are_weights_reshaped); + + // Create tensor to store the reshaped weights + reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution)); + ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */)); } else { - ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&im2_col_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo())); + TensorShape reshaped_weights_shape; + + if(is_fully_connected_convolution || is_quantized) + { + reshaped_weights_shape = TensorShape{ mat_weights_cols, mat_weights_rows }; + } + else + { + // Create tensor to store transposed weights + const float transpose_width = 16.0f / input->element_size(); + reshaped_weights_shape = TensorShape{ mat_weights_rows *static_cast(transpose_width), + static_cast(std::ceil(mat_weights_cols / transpose_width)) }; + } + + // Create tensor to store the reshaped weights + reshaped_weights->set_tensor_shape(get_reshaped_weights_shape_conv(weights, append_bias, is_fully_connected_convolution)); + ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases, reshaped_weights.get(), !is_fully_connected_convolution /* 1xW transpose */)); + weights = reshaped_weights.get(); + + // Validate GEMM interleave and multiply + if(is_interleaved) + { + TensorShape shape_interleaved = shape_im2col; + shape_interleaved.set(0, shape_interleaved.x() * 4); + shape_interleaved.set(1, std::ceil(shape_interleaved.y() / 4.f)); + TensorInfo input_interleaved_info = input->clone()->set_tensor_shape(shape_interleaved); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMInterleave4x4Kernel::validate(&im2_col_info, &input_interleaved_info)); + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&input_interleaved_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo(shape_im2col[1], // m + weights->tensor_shape()[0], // n + shape_im2col[0]) /* k */)); + } + else + { + ARM_COMPUTE_RETURN_ON_ERROR(NEGEMMMatrixMultiplyKernel::validate(&im2_col_info, weights, &gemm_output_info, 1.f, is_interleaved, GEMMReshapeInfo())); + } } ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(&gemm_output_info, output, Size2D(conv_w, conv_h))); -- cgit v1.2.1