aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels')
-rw-r--r--src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp29
1 files changed, 14 insertions, 15 deletions
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
index 7f9e9d20e1..1f481de921 100644
--- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
@@ -37,32 +37,33 @@
using namespace arm_compute;
-template <unsigned int kernel_size>
-CLDirectConvolutionLayerKernel<kernel_size>::CLDirectConvolutionLayerKernel()
+CLDirectConvolutionLayerKernel::CLDirectConvolutionLayerKernel()
: _input(nullptr), _biases(nullptr), _weights(nullptr), _output(nullptr), _border_size(0), _conv_pad_x(0), _conv_pad_y(0), _conv_stride_x(0), _conv_stride_y(0)
{
}
-template <unsigned int kernel_size>
-BorderSize CLDirectConvolutionLayerKernel<kernel_size>::border_size() const
+BorderSize CLDirectConvolutionLayerKernel::border_size() const
{
return _border_size;
}
-template <unsigned int kernel_size>
-void CLDirectConvolutionLayerKernel<kernel_size>::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info)
+void CLDirectConvolutionLayerKernel::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info)
{
- static_assert(kernel_size == 3, "Currently only 3x3 direct convolution is supported!");
-
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+ const unsigned int kernel_size = weights->info()->dimension(0);
+ ARM_COMPUTE_ERROR_ON_MSG(kernel_size != 1 && kernel_size != 3,
+ "Kernel sizes other than 1x1 or 3x3 are not supported");
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2));
ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != weights->info()->dimension(1));
ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);
+ ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
+ "Pad > 0 not supported for 1x1 weights");
+ ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
+ "Pad > 1 not supported for 3x3 weights");
+ ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
ARM_COMPUTE_ERROR_ON_MSG((kernel_size == 3 && std::get<0>(conv_info.stride()) > 2), "Strides larger than 2 not supported in 3x3 direct convolution!");
- ARM_COMPUTE_ERROR_ON(kernel_size != weights->info()->dimension(0));
-
if(biases != nullptr)
{
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
@@ -86,6 +87,7 @@ void CLDirectConvolutionLayerKernel<kernel_size>::configure(const ICLTensor *inp
kernel_name << "direct_convolution" << kernel_size << "x" << kernel_size;
options.insert("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type()));
+ options.insert("-DDATA_SIZE=" + get_data_size_from_data_type(input->info()->data_type()));
options.emplace("-DSTRIDE_X=" + support::cpp11::to_string(_conv_stride_x));
@@ -130,8 +132,7 @@ void CLDirectConvolutionLayerKernel<kernel_size>::configure(const ICLTensor *inp
ICLKernel::configure(win);
}
-template <unsigned int kernel_size>
-void CLDirectConvolutionLayerKernel<kernel_size>::run(const Window &window, cl::CommandQueue &queue)
+void CLDirectConvolutionLayerKernel::run(const Window &window, cl::CommandQueue &queue)
{
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
@@ -167,5 +168,3 @@ void CLDirectConvolutionLayerKernel<kernel_size>::run(const Window &window, cl::
}
while(window.slide_window_slice_3D(slice) && win_in.slide_window_slice_3D(slice_in));
}
-
-template class arm_compute::CLDirectConvolutionLayerKernel<3>;