diff options
Diffstat (limited to 'src/cpu/kernels/CpuElementwiseKernel.h')
-rw-r--r-- | src/cpu/kernels/CpuElementwiseKernel.h | 51 |
1 files changed, 25 insertions, 26 deletions
diff --git a/src/cpu/kernels/CpuElementwiseKernel.h b/src/cpu/kernels/CpuElementwiseKernel.h index 8cd5d58a96..2785e0a44c 100644 --- a/src/cpu/kernels/CpuElementwiseKernel.h +++ b/src/cpu/kernels/CpuElementwiseKernel.h @@ -39,23 +39,29 @@ namespace kernels * @f[ dst(x,y) = OP(src0(x,y), src1(x,y))@f] * */ -class CpuElementwiseKernel : public ICpuKernel<CpuElementwiseKernel> +template <class Derived> +class CpuElementwiseKernel : public ICpuKernel<Derived> { +private: + using ElementwiseKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const Window &)>::type; + public: CpuElementwiseKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuElementwiseKernel); using ElementwiseFunction = void(const ITensor *, const ITensor *, ITensor *, const Window &); - struct UKernelInfo - { - std::string name; - std::function<ElementwiseFunction> ukernel; - }; - // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; + const char *name() const override; + struct ElementwiseKernel + { + const char *name; + const ElementwiseDataTypeISASelectorPtr is_selected; + ElementwiseKernelPtr ukernel; + }; + protected: /** Validate the argument passed to the kernel * @@ -65,27 +71,12 @@ protected: */ static Status validate_arguments_common(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst); - /** Commmon configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff) - * - */ - void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst); - - /** Function to get the micro kernel implementation - * - * @param[in] src0 First input tensor information - * @param[in] src1 Second input tensor information - * @param[in] dst Output tensor information - * - * @return the function instance for the micro kernel - */ - virtual UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) = 0; - protected: std::function<ElementwiseFunction> _run_method{ nullptr }; std::string _name{}; }; -class CpuArithmeticKernel : public CpuElementwiseKernel +class CpuArithmeticKernel : public CpuElementwiseKernel<CpuArithmeticKernel> { public: CpuArithmeticKernel() = default; @@ -107,7 +98,12 @@ public: */ static Status validate(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst); + static const std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> &get_available_kernels(); + protected: + /** Commmon configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff) + */ + void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst); // Inherited methods overridden: static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst); @@ -122,7 +118,6 @@ private: * * @return the function instance for the micro kernel */ - UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) override; }; class CpuDivisionKernel : public CpuArithmeticKernel @@ -177,7 +172,7 @@ protected: static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst); }; -class CpuComparisonKernel : public CpuElementwiseKernel +class CpuComparisonKernel : public CpuElementwiseKernel<CpuComparisonKernel> { public: CpuComparisonKernel() = default; @@ -199,7 +194,12 @@ public: */ static Status validate(ComparisonOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst); + static const std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> &get_available_kernels(); + protected: + /** Commmon configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff) + */ + void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst); // Inherited methods overridden: static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst); @@ -212,7 +212,6 @@ private: * * @return the function instance for the micro kernel */ - UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) override; ComparisonOperation _op{}; }; |