aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuElementwiseKernel.h
diff options
context:
space:
mode:
authorDana Zlotnik <dana.zlotnik@arm.com>2022-01-17 09:54:26 +0200
committerDana Zlotnik <dana.zlotnik@arm.com>2022-02-14 12:49:53 +0000
commit6a2df886f32dcf7af4258163b0652f0fab07ecc5 (patch)
tree4ad16670d54d29de96df7cc5b582d52a6012255a /src/cpu/kernels/CpuElementwiseKernel.h
parent69854ba71f91f86c2a1c8a2301e91dcd93030561 (diff)
downloadComputeLibrary-6a2df886f32dcf7af4258163b0652f0fab07ecc5.tar.gz
Add kernel selection UT for submitted kernels
* Softmax kernel * Elementwise unary kernel * Elementwise binary ** This change require some refactor in the kernel cpp and h files Resolves COMPMID-5043 Change-Id: I58979b023ec31d759690847b3f85fc4baefbbf98 Signed-off-by: Dana Zlotnik <dana.zlotnik@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7033 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Diffstat (limited to 'src/cpu/kernels/CpuElementwiseKernel.h')
-rw-r--r--src/cpu/kernels/CpuElementwiseKernel.h51
1 files changed, 25 insertions, 26 deletions
diff --git a/src/cpu/kernels/CpuElementwiseKernel.h b/src/cpu/kernels/CpuElementwiseKernel.h
index 8cd5d58a96..2785e0a44c 100644
--- a/src/cpu/kernels/CpuElementwiseKernel.h
+++ b/src/cpu/kernels/CpuElementwiseKernel.h
@@ -39,23 +39,29 @@ namespace kernels
* @f[ dst(x,y) = OP(src0(x,y), src1(x,y))@f]
*
*/
-class CpuElementwiseKernel : public ICpuKernel<CpuElementwiseKernel>
+template <class Derived>
+class CpuElementwiseKernel : public ICpuKernel<Derived>
{
+private:
+ using ElementwiseKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, ITensor *, const Window &)>::type;
+
public:
CpuElementwiseKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuElementwiseKernel);
using ElementwiseFunction = void(const ITensor *, const ITensor *, ITensor *, const Window &);
- struct UKernelInfo
- {
- std::string name;
- std::function<ElementwiseFunction> ukernel;
- };
-
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
+
const char *name() const override;
+ struct ElementwiseKernel
+ {
+ const char *name;
+ const ElementwiseDataTypeISASelectorPtr is_selected;
+ ElementwiseKernelPtr ukernel;
+ };
+
protected:
/** Validate the argument passed to the kernel
*
@@ -65,27 +71,12 @@ protected:
*/
static Status validate_arguments_common(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
- /** Commmon configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff)
- *
- */
- void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst);
-
- /** Function to get the micro kernel implementation
- *
- * @param[in] src0 First input tensor information
- * @param[in] src1 Second input tensor information
- * @param[in] dst Output tensor information
- *
- * @return the function instance for the micro kernel
- */
- virtual UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) = 0;
-
protected:
std::function<ElementwiseFunction> _run_method{ nullptr };
std::string _name{};
};
-class CpuArithmeticKernel : public CpuElementwiseKernel
+class CpuArithmeticKernel : public CpuElementwiseKernel<CpuArithmeticKernel>
{
public:
CpuArithmeticKernel() = default;
@@ -107,7 +98,12 @@ public:
*/
static Status validate(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst);
+ static const std::vector<CpuElementwiseKernel<CpuArithmeticKernel>::ElementwiseKernel> &get_available_kernels();
+
protected:
+ /** Commmon configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff)
+ */
+ void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst);
// Inherited methods overridden:
static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
@@ -122,7 +118,6 @@ private:
*
* @return the function instance for the micro kernel
*/
- UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) override;
};
class CpuDivisionKernel : public CpuArithmeticKernel
@@ -177,7 +172,7 @@ protected:
static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
};
-class CpuComparisonKernel : public CpuElementwiseKernel
+class CpuComparisonKernel : public CpuElementwiseKernel<CpuComparisonKernel>
{
public:
CpuComparisonKernel() = default;
@@ -199,7 +194,12 @@ public:
*/
static Status validate(ComparisonOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst);
+ static const std::vector<CpuElementwiseKernel<CpuComparisonKernel>::ElementwiseKernel> &get_available_kernels();
+
protected:
+ /** Commmon configure function for element-wise operators with no additional options (e.g. Min, Max, SquaredDiff)
+ */
+ void configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst);
// Inherited methods overridden:
static Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst);
@@ -212,7 +212,6 @@ private:
*
* @return the function instance for the micro kernel
*/
- UKernelInfo get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst) override;
ComparisonOperation _op{};
};