aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuElementwiseKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuElementwiseKernel.h')
-rw-r--r--src/cpu/kernels/CpuElementwiseKernel.h51
1 files changed, 25 insertions, 26 deletions
diff --git a/src/cpu/kernels/CpuElementwiseKernel.h b/src/cpu/kernels/CpuElementwiseKernel.h
index 8cd5d58a96..2785e0a44c 100644
--- a/src/cpu/kernels/CpuElementwiseKernel.h
+++ b/src/cpu/kernels/CpuElementwiseKernel.h
@@ -39,23 +39,29 @@ namespace kernels
* @f[ dst(x,y) = OP(src0(x,y), src1(x,y))@f]
*
*/
-class CpuElementwiseKernel : public ICpuKernel<CpuElementwiseKernel>
+template <class Derived>
+class CpuElementwiseKernel : public ICpuKernel<Derived>
{
+private:
+ using ElementwiseKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const Window &)>::type;
+
public:
CpuElementwiseKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuElementwiseKernel);
using ElementwiseFunction = void(const ITensor *, const ITensor *, ITensor *, const Window &);
- struct UKernelInfo
- {
- std::string name;
- std::function<ElementwiseFunction> ukernel;
- };
-
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+
const char *name() const override;
+ struct ElementwiseKernel
+ {
+ const char *name;
+ const ElementwiseDataTypeISASelectorPtr is_selected;
+ ElementwiseKernelPtr ukernel;
+ };
+
protected:
/** Validate the argument passed to the kernel
*
@@ -65,27 +71,12 @@ protected:
*/
static Status validate_arguments_common(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
- /** Commmon configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff)
- *
- */
- void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst);
-
- /** Function to get the micro kernel implementation
- *
- * @param[in] src0 First input tensor information
- * @param[in] src1 Second input tensor information
- * @param[in] dst Output tensor information
- *
- * @return the function instance for the micro kernel
- */
- virtual UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) = 0;
-
protected:
std::function<ElementwiseFunction> _run_method{ nullptr };
std::string _name{};
};
-class CpuArithmeticKernel : public CpuElementwiseKernel
+class CpuArithmeticKernel : public CpuElementwiseKernel<CpuArithmeticKernel>
{
public:
CpuArithmeticKernel() = default;
@@ -107,7 +98,12 @@ public:
*/
static Status validate(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst);
+ static const std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> &get_available_kernels();
+
protected:
+ /** Commmon configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff)
+ */
+ void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst);
// Inherited methods overridden:
static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
@@ -122,7 +118,6 @@ private:
*
* @return the function instance for the micro kernel
*/
- UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) override;
};
class CpuDivisionKernel : public CpuArithmeticKernel
@@ -177,7 +172,7 @@ protected:
static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
};
-class CpuComparisonKernel : public CpuElementwiseKernel
+class CpuComparisonKernel : public CpuElementwiseKernel<CpuComparisonKernel>
{
public:
CpuComparisonKernel() = default;
@@ -199,7 +194,12 @@ public:
*/
static Status validate(ComparisonOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst);
+ static const std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> &get_available_kernels();
+
protected:
+ /** Commmon configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff)
+ */
+ void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst);
// Inherited methods overridden:
static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
@@ -212,7 +212,6 @@ private:
*
* @return the function instance for the micro kernel
*/
- UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) override;
ComparisonOperation _op{};
};