diff options
Diffstat (limited to 'src/cpu/kernels/CpuSoftmaxKernel.h')
-rw-r--r-- | src/cpu/kernels/CpuSoftmaxKernel.h | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h index 59f43bd1d2..5d288179fd 100644 --- a/src/cpu/kernels/CpuSoftmaxKernel.h +++ b/src/cpu/kernels/CpuSoftmaxKernel.h @@ -57,7 +57,7 @@ public: static Status validate(const ITensorInfo *src, const ITensorInfo *dst); // Inherited methods overridden: - void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; + void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; struct SoftmaxLogits1DMaxKernel @@ -70,7 +70,7 @@ public: static const std::vector<SoftmaxLogits1DMaxKernel> &get_available_kernels(); private: - SoftmaxLogits1DMaxKernelPtr _run_method{ nullptr }; + SoftmaxLogits1DMaxKernelPtr _run_method{nullptr}; std::string _name{}; }; @@ -79,7 +79,8 @@ template <bool IS_LOG = false> class CpuLogits1DSoftmaxKernel : public ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>> { private: - using SoftmaxLogits1DKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, void *const, ITensor *, float, bool, const Window &)>::type; + using SoftmaxLogits1DKernelPtr = std::add_pointer<void( + const ITensor *, const ITensor *, void *const, ITensor *, float, bool, const Window &)>::type; public: CpuLogits1DSoftmaxKernel() = default; @@ -95,18 +96,22 @@ public: * * @param tmp Auxiliary tensor info. Must be type F32 and same shape as the input. */ - void configure(const ITensorInfo *src, const ITensorInfo *max, ITensorInfo *dst, const float beta, ITensorInfo *tmp); + void + configure(const ITensorInfo *src, const ITensorInfo *max, ITensorInfo *dst, const float beta, ITensorInfo *tmp); /** Static function to check if given info will lead to a valid configuration * * Similar to CpuLogits1DSoftmaxKernel::configure() * * @return a status */ - static Status validate(const ITensorInfo *src, const ITensorInfo *max, - const ITensorInfo *dst, const float beta, const ITensorInfo *tmp); + static Status validate(const ITensorInfo *src, + const ITensorInfo *max, + const ITensorInfo *dst, + const float beta, + const ITensorInfo *tmp); // Inherited methods overridden: - void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; + void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; const char *name() const override; struct SoftmaxLogits1DKernel @@ -119,8 +124,8 @@ public: static const std::vector<SoftmaxLogits1DKernel> &get_available_kernels(); private: - float _beta{ 1.0f }; - SoftmaxLogits1DKernelPtr _run_method{ nullptr }; + float _beta{1.0f}; + SoftmaxLogits1DKernelPtr _run_method{nullptr}; std::string _name{}; }; } // namespace kernels |