From aa95ddc2abb7cef0b2edd03f7c4c9d9c6b9d7cf4 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 21 Jul 2020 22:45:13 +0100 Subject: COMPMID-3535: 9x9 Direct convolution support for CL and NHWC * Supported strides 1 and 2 Signed-off-by: Georgios Pinitas Change-Id: I4b9f087c0c328234159b2d1eacc2e465b3bb3c54 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3603 Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) (limited to 'src/core/CL/kernels') diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp index 0bf4afd81c..4acbe2dff8 100644 --- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp +++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp @@ -60,20 +60,9 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, "Weights feature map dimension should match the respective input's one"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 4, "Weights can be at most 4 dimensional"); ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 1) && std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported for 1x1 convolution."); - ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5) && std::get<0>(conv_info.stride()) > 2, - "Strides larger than 2 not supported for 3x3 convolution."); - - const auto data_type = input->data_type(); - - if(weights->dimension(width_idx) == 9) - { - const auto supported_data_layout = is_data_type_quantized(data_type) ? DataLayout::NCHW : DataLayout::NHWC; - const auto error_message = std::string("Only " + string_from_data_layout(supported_data_layout) + " layout is supported for 9x9 convolution with " + string_from_data_type( - data_type) - + " type"); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG((supported_data_layout != data_layout), error_message.c_str()); - } + ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5 || weights->dimension(width_idx) == 9) + && std::get<0>(conv_info.stride()) > 2, + "Strides larger than 2 not supported for 3x3, 5x5, 9x9 convolution."); if(biases != nullptr) { @@ -99,6 +88,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); } + const auto data_type = input->data_type(); if(is_data_type_quantized(data_type)) { const UniformQuantizationInfo iqinfo = input->quantization_info().uniform(); -- cgit v1.2.1