From b9d38ee6378f3035f8dbad442223d3d9e2f3dc4f Mon Sep 17 00:00:00 2001 From: Frank Lei Date: Tue, 5 Dec 2017 10:43:33 +0800 Subject: APPBROWSER-312 Fully connected performance optimization Change-Id: Ie93fd630ebbad7b6ca8812cb5044b3f1908b45fd Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111830 Reviewed-by: Stephen Li Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com Reviewed-by: Anthony Barbier --- .../kernels/GCDirectConvolutionLayerKernel.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) (limited to 'src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp') diff --git a/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp index b032bc5668..a7d721d035 100644 --- a/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp +++ b/src/core/GLES_COMPUTE/kernels/GCDirectConvolutionLayerKernel.cpp @@ -53,7 +53,6 @@ template void GCDirectConvolutionLayerKernel::configure(const IGCTensor *input, const IGCTensor *weights, const IGCTensor *bias, IGCTensor *output, const PadStrideInfo &conv_info) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output); ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2)); ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != weights->info()->dimension(1)); ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4); @@ -68,6 +67,24 @@ void GCDirectConvolutionLayerKernel::configure(const IGCTensor *inp ARM_COMPUTE_ERROR_ON(bias->info()->num_dimensions() > 1); } + // Get convolved dimensions + unsigned int owidth = 0; + unsigned int oheight = 0; + std::tie(owidth, oheight) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), kernel_size, kernel_size, conv_info); + + TensorShape output_shape = input->info()->tensor_shape(); + output_shape.set(0, owidth); + output_shape.set(1, oheight); + output_shape.set(2, weights->info()->dimension(3)); + + // Output auto inizialitation if not yet initialized + auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position()); + + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output); + ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape); + ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output); + _conv_stride_x = std::get<0>(conv_info.stride()); _conv_stride_y = std::get<1>(conv_info.stride()); _conv_pad_x = std::get<0>(conv_info.pad()); -- cgit v1.2.1