aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp28
1 files changed, 25 insertions, 3 deletions
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
index 10119d8e8e..7b74a5a98c 100644
--- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
@@ -56,14 +56,23 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights,
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(width_idx) != weights->dimension(height_idx), "Weights should have same width and height");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(width_idx) != 1 && weights->dimension(width_idx) != 3 && weights->dimension(width_idx) != 5 && weights->dimension(width_idx) != 9,
"Kernel sizes other than 1x1, 3x3, 5x5 or 9x9 are not supported");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(width_idx) == 9 && input->data_type() == DataType::QASYMM8, "Kernel sizes of 9x9 is not supported for quantized types");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(channel_idx) != input->dimension(channel_idx),
"Weights feature map dimension should match the respective input's one");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->num_dimensions() > 4, "Weights can be at most 4 dimensional");
ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 1) && std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported for 1x1 convolution.");
ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 3 || weights->dimension(width_idx) == 5) && std::get<0>(conv_info.stride()) > 2,
"Strides larger than 2 not supported for 3x3 convolution.");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((weights->dimension(width_idx) == 9) && data_layout == DataLayout::NCHW, "Only NHWC layout is supported for 9x9 convolution.");
+
+ const auto data_type = input->data_type();
+
+ if(weights->dimension(width_idx) == 9)
+ {
+ const auto supported_data_layout = is_data_type_quantized(data_type) ? DataLayout::NCHW : DataLayout::NHWC;
+ const auto error_message = std::string("Only " + string_from_data_layout(supported_data_layout) + " layout is supported for 9x9 convolution with " + string_from_data_type(
+ data_type) + " type");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((supported_data_layout != data_layout), error_message.c_str());
+ }
if(biases != nullptr)
{
@@ -226,6 +235,19 @@ inline void setup_num_elems(unsigned int &num_elems_read_per_iteration_x, unsign
ARM_COMPUTE_ERROR("Invalid convolution stride X");
}
break;
+ case 9:
+ switch(conv_stride_x)
+ {
+ case 1:
+ num_elems_read_per_iteration_x = 16;
+ break;
+ case 2:
+ num_elems_read_per_iteration_x = 24;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Invalid convolution stride X");
+ }
+ break;
default:
ARM_COMPUTE_ERROR("Invalid direct convolution size");
}
@@ -487,7 +509,7 @@ void CLDirectConvolutionLayerKernel::configure(const ICLTensor *input, const ICL
}
build_options.add_option(std::string("-DDATA_TYPE_PROMOTED=" + get_cl_type_from_data_type(data_type)));
// Create kernel
- _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(is_quantized_asymm ? "direct_convolution_1x1_3x3_5x5_quantized" : kernel_name.str(),
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(is_quantized_asymm ? "direct_convolution_quantized" : kernel_name.str(),
build_options.options()));
}