aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h13
1 files changed, 8 insertions, 5 deletions
diff --git a/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h
index d6fc8f0fcc..82947bc7e6 100644
--- a/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h
@@ -24,13 +24,15 @@
#ifndef __ARM_COMPUTE_CLDEPTHWISECONVOLUTION_H__
#define __ARM_COMPUTE_CLDEPTHWISECONVOLUTION_H__
-#include "arm_compute/core/CL/kernels/CLDepthwiseConvolutionLayer3x3Kernel.h"
+#include "arm_compute/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.h"
+#include "arm_compute/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.h"
#include "arm_compute/core/CL/kernels/CLDepthwiseIm2ColKernel.h"
#include "arm_compute/core/CL/kernels/CLDepthwiseVectorToTensorKernel.h"
#include "arm_compute/core/CL/kernels/CLDepthwiseWeightsReshapeKernel.h"
#include "arm_compute/core/CL/kernels/CLDirectConvolutionLayerOutputStageKernel.h"
#include "arm_compute/core/CL/kernels/CLFillBorderKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.h"
+#include "arm_compute/core/CL/kernels/ICLDepthwiseConvolutionLayer3x3Kernel.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/IFunction.h"
@@ -39,9 +41,10 @@ namespace arm_compute
{
class ICLTensor;
-/** Basic function to execute a depthwise convolution for kernel size 3x3xC. This function calls the following OpenCL kernels:
+/** Basic function to execute a depthwise convolution for kernel size 3x3xC (when data layout NCHW) or Cx3x3 (when data layout NHWC). This function calls the following OpenCL kernels:
*
- * -# @ref CLDepthwiseConvolutionLayer3x3Kernel
+ * -# @ref CLDepthwiseConvolutionLayer3x3NCHWKernel (if data_layout == NCHW)
+ * -# @ref CLDepthwiseConvolutionLayer3x3NHWCKernel (if data_layout == NHWC)
* -# @ref CLFillBorderKernel (if pad_x or pad_y > 0)
*
*/
@@ -66,8 +69,8 @@ public:
void run() override;
private:
- CLDepthwiseConvolutionLayer3x3Kernel _kernel;
- CLFillBorderKernel _border_handler;
+ std::unique_ptr<ICLDepthwiseConvolutionLayer3x3Kernel> _kernel;
+ CLFillBorderKernel _border_handler;
};
/** Basic function to execute a generic depthwise convolution. This function calls the following OpenCL kernels: