diff options
Diffstat (limited to 'src/cpu/kernels/CpuSoftmaxKernel.h')
-rw-r--r-- | src/cpu/kernels/CpuSoftmaxKernel.h | 36 |
1 files changed, 22 insertions, 14 deletions
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h index f317662620..df7d3f7d9b 100644 --- a/src/cpu/kernels/CpuSoftmaxKernel.h +++ b/src/cpu/kernels/CpuSoftmaxKernel.h @@ -23,10 +23,8 @@ */ #ifndef ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H #define ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H - #include "src/core/common/Macros.h" #include "src/cpu/ICpuKernel.h" - namespace arm_compute { namespace cpu @@ -34,8 +32,11 @@ namespace cpu namespace kernels { /** Interface for the identifying the max value of 1D Logits */ -class CpuLogits1DMaxKernel : public NewICpuKernel<CpuLogits1DMaxKernel> +class CpuLogits1DMaxKernel : public ICpuKernel<CpuLogits1DMaxKernel> { +private: + using SoftmaxLogits1DMaxKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &)>::type; + public: CpuLogits1DMaxKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuLogits1DMaxKernel); @@ -52,27 +53,31 @@ public: * @return a status */ static Status validate(const ITensorInfo *src, const ITensorInfo *dst); - // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; - -private: - using SoftmaxLogits1DMaxKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &)>::type; + struct SoftmaxLogits1DMaxKernel + { + const char *name; + const DataTypeISASelectorPtr is_selected; + SoftmaxLogits1DMaxKernelPtr ukernel; + }; + static const std::vector<SoftmaxLogits1DMaxKernel> &get_available_kernels(); private: SoftmaxLogits1DMaxKernelPtr _run_method{ nullptr }; std::string _name{}; }; - /** Interface for softmax computation for QASYMM8 with pre-computed max. */ template <bool IS_LOG = false> -class CpuLogits1DSoftmaxKernel : public NewICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>> +class CpuLogits1DSoftmaxKernel : public ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>> { +private: + using SoftmaxLogits1DKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, void *const, ITensor *, float, bool, const Window &)>::type; + public: CpuLogits1DSoftmaxKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuLogits1DSoftmaxKernel); - /** Set the input and output tensors. * * @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32. @@ -92,13 +97,16 @@ public: */ static Status validate(const ITensorInfo *src, const ITensorInfo *max, const ITensorInfo *dst, const float beta, const ITensorInfo *tmp); - // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; - -private: - using SoftmaxLogits1DKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, void *const, ITensor *, float, bool, const Window &)>::type; + struct SoftmaxLogits1DKernel + { + const char *name; + const DataTypeISASelectorPtr is_selected; + SoftmaxLogits1DKernelPtr ukernel; + }; + static const std::vector<SoftmaxLogits1DKernel> &get_available_kernels(); private: float _beta{ 1.0f }; |