aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuDirectConv2dKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuDirectConv2dKernel.h')
-rw-r--r-- src/cpu/kernels/CpuDirectConv2dKernel.h 22
1 files changed, 11 insertions, 11 deletions
diff --git a/src/cpu/kernels/CpuDirectConv2dKernel.h b/src/cpu/kernels/CpuDirectConv2dKernel.h
index 6ec4d4ee04..b9265dc630 100644
--- a/src/cpu/kernels/CpuDirectConv2dKernel.h
+++ b/src/cpu/kernels/CpuDirectConv2dKernel.h
@@ -36,6 +36,9 @@ namespace kernels
/** Interface for the kernel to perform Direct Convolution Layer. */
class CpuDirectConv2dKernel : public ICpuKernel<CpuDirectConv2dKernel>
{
+private:
+ using DirectConv2dKernel_Ptr = std::add_pointer<void(const Window &, const ITensor *, const ITensor *, ITensor *, const PadStrideInfo &)>::type;
+
public:
CpuDirectConv2dKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDirectConv2dKernel);
@@ -67,19 +70,16 @@ public:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
-private:
- /* Template function for optimized convolution NHWC */
- template <typename T>
- void convolve_nhwc_optimized(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst);
+ struct DirectConv2dKernel /* Element type of the vector returned by get_available_kernels() */
+ {
+ const char *name; /**< Name string for this entry */
+ const DataTypeDataLayoutSelectorPtr is_selected; /**< Selector for this entry; presumably a predicate on data type and data layout - TODO confirm against the DataTypeDataLayoutSelectorPtr declaration */
+ DirectConv2dKernel_Ptr ukernel; /**< Function pointer taking (window, two const tensors, one mutable tensor, PadStrideInfo); see DirectConv2dKernel_Ptr */
+ };
- /* Template function for convolution NHWC */
- template <typename T>
- void convolve_nhwc(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst);
-
- /* Template function for convolution NCHW */
- template <typename T>
- void convolve_nchw(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst);
+ static const std::vector<DirectConv2dKernel> &get_available_kernels();
+private:
PadStrideInfo _conv_info{};
unsigned int _kernel_size{ 0 };
DataLayout _data_layout{ DataLayout::UNKNOWN };