diff options
Diffstat (limited to 'src/cpu/kernels/CpuSoftmaxKernel.h')
-rw-r--r-- | src/cpu/kernels/CpuSoftmaxKernel.h | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h index df7d3f7d9b..59f43bd1d2 100644 --- a/src/cpu/kernels/CpuSoftmaxKernel.h +++ b/src/cpu/kernels/CpuSoftmaxKernel.h @@ -23,8 +23,10 @@ */ #ifndef ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H #define ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H + #include "src/core/common/Macros.h" #include "src/cpu/ICpuKernel.h" + namespace arm_compute { namespace cpu @@ -53,21 +55,25 @@ public: * @return a status */ static Status validate(const ITensorInfo *src, const ITensorInfo *dst); + // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + struct SoftmaxLogits1DMaxKernel { const char *name; const DataTypeISASelectorPtr is_selected; SoftmaxLogits1DMaxKernelPtr ukernel; }; + static const std::vector<SoftmaxLogits1DMaxKernel> &get_available_kernels(); private: SoftmaxLogits1DMaxKernelPtr _run_method{ nullptr }; std::string _name{}; }; + /** Interface for softmax computation for QASYMM8 with pre-computed max. */ template <bool IS_LOG = false> class CpuLogits1DSoftmaxKernel : public ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>> @@ -78,6 +84,7 @@ private: public: CpuLogits1DSoftmaxKernel() = default; ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuLogits1DSoftmaxKernel); + /** Set the input and output tensors. * * @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32. @@ -97,15 +104,18 @@ public: */ static Status validate(const ITensorInfo *src, const ITensorInfo *max, const ITensorInfo *dst, const float beta, const ITensorInfo *tmp); + // Inherited methods overridden: void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; + struct SoftmaxLogits1DKernel { const char *name; const DataTypeISASelectorPtr is_selected; SoftmaxLogits1DKernelPtr ukernel; }; + static const std::vector<SoftmaxLogits1DKernel> &get_available_kernels(); private: |