diff options
Diffstat (limited to 'src/cpu/kernels/softmax/generic/neon/fp16.cpp')
-rw-r--r-- | src/cpu/kernels/softmax/generic/neon/fp16.cpp | 21 |
1 files changed, 9 insertions, 12 deletions
diff --git a/src/cpu/kernels/softmax/generic/neon/fp16.cpp b/src/cpu/kernels/softmax/generic/neon/fp16.cpp index 2e2adf33e0..db8f881712 100644 --- a/src/cpu/kernels/softmax/generic/neon/fp16.cpp +++ b/src/cpu/kernels/softmax/generic/neon/fp16.cpp @@ -31,21 +31,18 @@ namespace arm_compute { namespace cpu { -void neon_fp16_softmax(const ITensor *in, - const ITensor *max, - void *const tmp, - ITensor *out, - const float beta, - bool is_log, - const Window &window) -{ - return neon_softmax_logits_1d_float<float16_t>(in, max, tmp, out, beta, is_log, window); -} -void neon_fp16_logits(const ITensor *in, ITensor *out, const Window &window) +template <bool IS_LOG> +void neon_fp16_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window) { - return neon_logits_1d_max<float16_t>(in, out, window); + return neon_softmax_float<float16_t, IS_LOG>(in, tmp, out, beta, window); } + +template void +neon_fp16_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window); +template void +neon_fp16_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window); + } // namespace cpu } // namespace arm_compute #endif //defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) |