diff options
Diffstat (limited to 'arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h')
-rw-r--r-- | arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h b/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h index 5bf9a5d57f..f1409b6339 100644 --- a/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h +++ b/arm_compute/core/CL/kernels/CLDirectConvolutionLayerKernel.h @@ -68,6 +68,27 @@ public: * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. */ void configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info); + /** Set the input, weights, biases and output tensors. + * + * @note: DirectConvolution only works in the following configurations: + * 1x1 convolution with stride_x = 1/2/3, stride_y = 1/2/3 + * 3x3 convolution with stride_x = 1/2, stride_y = 1/2 + * 5x5 convolution with stride_x = 1/2, stride_y = 1/2 + * 9x9 convolution with stride_x = 1/2, stride_y = 1/2, data_layout=NHWC + * + * @param[in] compile_context The compile context to be used. + * @param[in] input The input tensor to convolve. 3 lower dimensions represent a single input [width, height, IFM], + * while every optional dimension from 4 and above represent a batch of inputs. Data types supported: QASYMM8_SIGNED/QASYMM8/F16/F32. + * @param[in] weights Weights tensor. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. + * The 3rd dimension must be the same as the input's volume 3rd dimension. + * Data type supported:Same as @p input. + * @param[in] biases Biases tensor. Biases are 1D tensor with dimension [OFM]. + * Data type supported: Should match @p input data type, except for input of QASYMM8 and QASYMM8_SIGNED type where biases should be of S32 type + * @param[out] output Output tensor. + * The 3rd dimensions must be equal to the 4th dimension of the @p kernels tensor. Data types supported: Same as @p input. + * @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo. + */ + void configure(CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info); /** Static function to check if given info will lead to a valid configuration of @ref CLDirectConvolutionLayerKernel * * @param[in] input The input tensor to convolve. 3 lower dimensions represent a single input [width, height, IFM], |