diff options
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h index d6fc8f0fcc..82947bc7e6 100644 --- a/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h +++ b/arm_compute/runtime/CL/functions/CLDepthwiseConvolutionLayer.h @@ -24,13 +24,15 @@ #ifndef __ARM_COMPUTE_CLDEPTHWISECONVOLUTION_H__ #define __ARM_COMPUTE_CLDEPTHWISECONVOLUTION_H__ -#include "arm_compute/core/CL/kernels/CLDepthwiseConvolutionLayer3x3Kernel.h" +#include "arm_compute/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NCHWKernel.h" +#include "arm_compute/core/CL/kernels/CLDepthwiseConvolutionLayer3x3NHWCKernel.h" #include "arm_compute/core/CL/kernels/CLDepthwiseIm2ColKernel.h" #include "arm_compute/core/CL/kernels/CLDepthwiseVectorToTensorKernel.h" #include "arm_compute/core/CL/kernels/CLDepthwiseWeightsReshapeKernel.h" #include "arm_compute/core/CL/kernels/CLDirectConvolutionLayerOutputStageKernel.h" #include "arm_compute/core/CL/kernels/CLFillBorderKernel.h" #include "arm_compute/core/CL/kernels/CLGEMMMatrixVectorMultiplyKernel.h" +#include "arm_compute/core/CL/kernels/ICLDepthwiseConvolutionLayer3x3Kernel.h" #include "arm_compute/core/Types.h" #include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/IFunction.h" @@ -39,9 +41,10 @@ namespace arm_compute { class ICLTensor; -/** Basic function to execute a depthwise convolution for kernel size 3x3xC. This function calls the following OpenCL kernels: +/** Basic function to execute a depthwise convolution for kernel size 3x3xC (when data layout NCHW) or Cx3x3 (when data layout NHWC). This function calls the following OpenCL kernels: * - * -# @ref CLDepthwiseConvolutionLayer3x3Kernel + * -# @ref CLDepthwiseConvolutionLayer3x3NCHWKernel (if data_layout == NCHW) + * -# @ref CLDepthwiseConvolutionLayer3x3NHWCKernel (if data_layout == NHWC) * -# @ref CLFillBorderKernel (if pad_x or pad_y > 0) * */ @@ -66,8 +69,8 @@ public: void run() override; private: - CLDepthwiseConvolutionLayer3x3Kernel _kernel; - CLFillBorderKernel _border_handler; + std::unique_ptr<ICLDepthwiseConvolutionLayer3x3Kernel> _kernel; + CLFillBorderKernel _border_handler; }; /** Basic function to execute a generic depthwise convolution. This function calls the following OpenCL kernels: |