aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp18
1 files changed, 4 insertions, 14 deletions
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
index 0bf4afd81c..4acbe2dff8 100644
--- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
@@ -60,20 +60,9 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights,
"Weights feature map dimension should match the respective input's one");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 4, "Weights can be at most 4 dimensional");
ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 1) && std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported for 1x1 convolution.");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5) && std::get<0>(conv_info.stride()) > 2,
- "Strides larger than 2 not supported for 3x3 convolution.");
-
- const auto data_type = input->data_type();
-
- if(weights->dimension(width_idx) == 9)
- {
- const auto supported_data_layout = is_data_type_quantized(data_type) ? DataLayout::NCHW : DataLayout::NHWC;
- const auto error_message = std::string("Only " + string_from_data_layout(supported_data_layout) + " layout is supported for 9x9 convolution with " + string_from_data_type(
- data_type)
- + " type");
-
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((supported_data_layout != data_layout), error_message.c_str());
- }
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5 || weights->dimension(width_idx) == 9)
+ && std::get<0>(conv_info.stride()) > 2,
+ "Strides larger than 2 not supported for 3x3, 5x5, 9x9 convolution.");
if(biases != nullptr)
{
@@ -99,6 +88,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights,
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
}
+ const auto data_type = input->data_type();
if(is_data_type_quantized(data_type))
{
const UniformQuantizationInfo iqinfo = input->quantization_info().uniform();