aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuSoftmaxKernel.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuSoftmaxKernel.h')
-rw-r--r--src/cpu/kernels/CpuSoftmaxKernel.h12
1 files changed, 7 insertions, 5 deletions
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h
index 3db1f3d0ef..043ad975d5 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.h
+++ b/src/cpu/kernels/CpuSoftmaxKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,7 +38,7 @@ class CpuSoftmaxKernel : public ICpuKernel<CpuSoftmaxKernel>
{
private:
using SoftmaxKernelPtr =
- std::add_pointer<void(const ITensor *, void *const, ITensor *, float, const Window &)>::type;
+ std::add_pointer<void(const ITensor *, void *const, ITensor *, float, int, const Window &)>::type;
public:
CpuSoftmaxKernel() = default;
@@ -49,11 +49,12 @@ public:
* @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[out] dst Destination tensor info. Data types supported: same as @p input.
* @param[in] beta A scaling factor for the exponent.
- * @param[in] is_log True if the operation is log-softmax
+ * @param[in] is_log True if the operation is log-softmax.
+ * @param[in] axis The axis along which to perform the softmax operation.
*
* @param tmp Auxiliary tensor info. Must be type F32 and same shape as the input.
*/
- void configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, ITensorInfo *tmp);
+ void configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuSoftmaxKernel::configure()
@@ -61,7 +62,7 @@ public:
* @return a status
*/
static Status
- validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp);
+ validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp);
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
@@ -80,6 +81,7 @@ private:
float _beta{1.0f};
SoftmaxKernelPtr _run_method{nullptr};
std::string _name{};
+ int _axis{};
};
} // namespace kernels
} // namespace cpu