diff options
Diffstat (limited to 'src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp | 29 |
1 files changed, 14 insertions, 15 deletions
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp index 7f9e9d20e1..1f481de921 100644 --- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp +++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp @@ -37,32 +37,33 @@ using namespace arm_compute; -template <unsigned int kernel_size> -CLDirectConvolutionLayerKernel<kernel_size>::CLDirectConvolutionLayerKernel() +CLDirectConvolutionLayerKernel::CLDirectConvolutionLayerKernel() : _input(nullptr), _biases(nullptr), _weights(nullptr), _output(nullptr), _border_size(0), _conv_pad_x(0), _conv_pad_y(0), _conv_stride_x(0), _conv_stride_y(0) { } -template <unsigned int kernel_size> -BorderSize CLDirectConvolutionLayerKernel<kernel_size>::border_size() const +BorderSize CLDirectConvolutionLayerKernel::border_size() const { return _border_size; } -template <unsigned int kernel_size> -void CLDirectConvolutionLayerKernel<kernel_size>::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info) +void CLDirectConvolutionLayerKernel::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info) { - static_assert(kernel_size == 3, "Currently only 3x3 direct convolution is supported!"); - - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32); + const unsigned int kernel_size = weights->info()->dimension(0); + ARM_COMPUTE_ERROR_ON_MSG(kernel_size != 1 && kernel_size != 3, + "Kernel sizes other than 1x1 or 3x3 are not supported"); + ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output); ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2)); ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != weights->info()->dimension(1)); ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4); + ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())), + "Pad > 0 not supported for 1x1 weights"); + ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1), + "Pad > 1 not supported for 3x3 weights"); + ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported."); ARM_COMPUTE_ERROR_ON_MSG((kernel_size == 3 && std::get<0>(conv_info.stride()) > 2), "Strides larger than 2 not supported in 3x3 direct convolution!"); - ARM_COMPUTE_ERROR_ON(kernel_size != weights->info()->dimension(0)); - if(biases != nullptr) { ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases); @@ -86,6 +87,7 @@ void CLDirectConvolutionLayerKernel<kernel_size>::configure(const ICLTensor *inp kernel_name << "direct_convolution" << kernel_size << "x" << kernel_size; options.insert("-DDATA_TYPE=" + get_cl_type_from_data_type(input->info()->data_type())); + options.insert("-DDATA_SIZE=" + get_data_size_from_data_type(input->info()->data_type())); options.emplace("-DSTRIDE_X=" + support::cpp11::to_string(_conv_stride_x)); @@ -130,8 +132,7 @@ void CLDirectConvolutionLayerKernel<kernel_size>::configure(const ICLTensor *inp ICLKernel::configure(win); } -template <unsigned int kernel_size> -void CLDirectConvolutionLayerKernel<kernel_size>::run(const Window &window, cl::CommandQueue &queue) +void CLDirectConvolutionLayerKernel::run(const Window &window, cl::CommandQueue &queue) { ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); @@ -167,5 +168,3 @@ void CLDirectConvolutionLayerKernel<kernel_size>::run(const Window &window, cl:: } while(window.slide_window_slice_3D(slice) && win_in.slide_window_slice_3D(slice_in)); } - -template class arm_compute::CLDirectConvolutionLayerKernel<3>; |