aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuSoftmaxKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuSoftmaxKernel.h')
-rw-r--r--src/cpu/kernels/CpuSoftmaxKernel.h10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h
index df7d3f7d9b..59f43bd1d2 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.h
+++ b/src/cpu/kernels/CpuSoftmaxKernel.h
@@ -23,8 +23,10 @@
*/
#ifndef ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H
#define ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H
+
#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"
+
namespace arm_compute
{
namespace cpu
@@ -53,21 +55,25 @@ public:
* @return a status
*/
static Status validate(const ITensorInfo *src, const ITensorInfo *dst);
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
+
struct SoftmaxLogits1DMaxKernel
{
const char *name;
const DataTypeISASelectorPtr is_selected;
SoftmaxLogits1DMaxKernelPtr ukernel;
};
+
static const std::vector<SoftmaxLogits1DMaxKernel> &get_available_kernels();
private:
SoftmaxLogits1DMaxKernelPtr _run_method{ nullptr };
std::string _name{};
};
+
/** Interface for softmax computation for QASYMM8 with pre-computed max. */
template <bool IS_LOG = false>
class CpuLogits1DSoftmaxKernel : public ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>>
@@ -78,6 +84,7 @@ private:
public:
CpuLogits1DSoftmaxKernel() = default;
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuLogits1DSoftmaxKernel);
+
/** Set the input and output tensors.
*
* @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
@@ -97,15 +104,18 @@ public:
*/
static Status validate(const ITensorInfo *src, const ITensorInfo *max,
const ITensorInfo *dst, const float beta, const ITensorInfo *tmp);
+
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
const char *name() const override;
+
struct SoftmaxLogits1DKernel
{
const char *name;
const DataTypeISASelectorPtr is_selected;
SoftmaxLogits1DKernelPtr ukernel;
};
+
static const std::vector<SoftmaxLogits1DKernel> &get_available_kernels();
private: