diff options
Diffstat (limited to 'src/cpu/kernels/CpuDirectConv2dKernel.h')
-rw-r--r-- | src/cpu/kernels/CpuDirectConv2dKernel.h | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/src/cpu/kernels/CpuDirectConv2dKernel.h b/src/cpu/kernels/CpuDirectConv2dKernel.h index 6ec4d4ee04..b9265dc630 100644 --- a/src/cpu/kernels/CpuDirectConv2dKernel.h +++ b/src/cpu/kernels/CpuDirectConv2dKernel.h @@ -36,6 +36,9 @@ namespace kernels /** Interface for the kernel to perform Direct Convolution Layer. */ class CpuDirectConv2dKernel : public ICpuKernel<CpuDirectConv2dKernel> { +private: + using DirectConv2dKernel_Ptr = std::add_pointer<void(const Window &, const ITensor *, const ITensor *, ITensor *, const PadStrideInfo &)>::type; + public: CpuDirectConv2dKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDirectConv2dKernel); @@ -67,19 +70,16 @@ public: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; -private: - /* Template function for optimized convolution NHWC */ - template <typename T> - void convolve_nhwc_optimized(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst); + struct DirectConv2dKernel + { + const char *name; + const DataTypeDataLayoutSelectorPtr is_selected; + DirectConv2dKernel_Ptr ukernel; + }; - /* Template function for convolution NHWC */ - template <typename T> - void convolve_nhwc(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst); - - /* Template function for convolution NCHW */ - template <typename T> - void convolve_nchw(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst); + static const std::vector<DirectConv2dKernel> &get_available_kernels(); +private: PadStrideInfo _conv_info{}; unsigned int _kernel_size{ 0 }; DataLayout _data_layout{ DataLayout::UNKNOWN }; |