aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuDirectConv3dKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuDirectConv3dKernel.h')
-rw-r--r--src/cpu/kernels/CpuDirectConv3dKernel.h22
1 files changed, 17 insertions, 5 deletions
diff --git a/src/cpu/kernels/CpuDirectConv3dKernel.h b/src/cpu/kernels/CpuDirectConv3dKernel.h
index ff3b30f8ae..6ae70bd3b7 100644
--- a/src/cpu/kernels/CpuDirectConv3dKernel.h
+++ b/src/cpu/kernels/CpuDirectConv3dKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/runtime/FunctionDescriptors.h"
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
+
namespace arm_compute
{
namespace cpu
@@ -34,8 +35,12 @@ namespace cpu
namespace kernels
{
/** Interface for the kernel to perform 3D Direct Convolution Layer. */
-class CpuDirectConv3dKernel : public ICpuKernel
+class CpuDirectConv3dKernel : public NewICpuKernel<CpuDirectConv3dKernel>
{
+private:
+ /* Template function for convolution 3d NDHWC */
+ using DirectConv3dKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Conv3dInfo &, const Window &)>::type;
+
public:
CpuDirectConv3dKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDirectConv3dKernel);
@@ -71,14 +76,21 @@ public:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
-private:
- /* Template function for convolution 3d NDHWC */
- using DirectConv3dKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Conv3dInfo &, const Window &)>::type;
+ struct DirectConv3dKernel
+ {
+ const char *name;
+ const DataTypeISASelectorPtr is_selected;
+ DirectConv3dKernelPtr ukernel;
+ };
+
+ static const std::vector<DirectConv3dKernel> &get_available_kernels();
+private:
Conv3dInfo _conv_info{};
DirectConv3dKernelPtr _run_method{ nullptr };
std::string _name{};
};
+
} // namespace kernels
} // namespace cpu
} // namespace arm_compute