diff options
Diffstat (limited to 'src/cpu/kernels/CpuDepthwiseConv2dNativeKernel.h')
-rw-r--r-- | src/cpu/kernels/CpuDepthwiseConv2dNativeKernel.h | 35 |
1 files changed, 14 insertions, 21 deletions
diff --git a/src/cpu/kernels/CpuDepthwiseConv2dNativeKernel.h b/src/cpu/kernels/CpuDepthwiseConv2dNativeKernel.h index e23a0fac87..95835e6dcf 100644 --- a/src/cpu/kernels/CpuDepthwiseConv2dNativeKernel.h +++ b/src/cpu/kernels/CpuDepthwiseConv2dNativeKernel.h @@ -42,6 +42,10 @@ namespace kernels /** Interface for the kernel to run a depthwise convolution native on a tensor. */ class CpuDepthwiseConv2dNativeKernel : public ICpuKernel<CpuDepthwiseConv2dNativeKernel> { +private: + using DepthwiseConv2dNativeKernelPtr = + std::add_pointer<void(const ITensor *, const ITensor *, const ITensor *, ITensor *, const Window &, bool, const ConvolutionInfo &)>::type; + public: CpuDepthwiseConv2dNativeKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDepthwiseConv2dNativeKernel); @@ -71,33 +75,22 @@ public: // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + struct DepthwiseConv2dNativeKernel + { + const char *name; + const DepthwiseConv2dNativeDataTypeISASelectorPtr is_selected; + DepthwiseConv2dNativeKernelPtr ukernel; + }; + static const std::vector<DepthwiseConv2dNativeKernel> &get_available_kernels(); private: - template <typename T> - using FloatEnalber = typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, int>::type; - - template <typename T, typename TW, FloatEnalber<T> = 0> - void run_depthwise(const ITensor *src, const ITensor *weights, const ITensor *bias, ITensor *dst, const Window &window, bool has_biases); - - template <typename T> - using Quantized8bitEnalber = typename std::enable_if < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int >::type; - - template <typename T, typename TW, Quantized8bitEnalber<T> = 0> - void run_depthwise(const ITensor *src, const ITensor *weights, const ITensor *bias, ITensor *dst, const Window &window, bool has_biases); - /** Common signature for all the specialised depthwise convolution native functions * * @param[in] window Region on which to execute the kernel. */ - using DepthwiseFunctionPtr = void (CpuDepthwiseConv2dNativeKernel::*)(const ITensor *src, const ITensor *weights, const ITensor *bias, ITensor *dst, const Window &window, bool has_biases); - - DepthwiseFunctionPtr _func{ nullptr }; - PadStrideInfo _conv_info{}; - unsigned int _depth_multiplier{ 1 }; - Size2D _dilation{}; - std::vector<int> _output_multiplier{}; - std::vector<int> _output_shift{}; - bool _has_biases{ false }; + DepthwiseConv2dNativeKernelPtr _func{ nullptr }; + ConvolutionInfo _conv_info{}; + bool _has_biases{ false }; }; } // namespace kernels } // namespace cpu |