aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOmar Al Khatib <omar.alkhatib@arm.com>2024-01-02 14:45:07 +0000
committerOmar Al Khatib <omar.alkhatib@arm.com>2024-03-12 15:45:42 +0000
commit93e743fbe7d52f4c41fcd90762fc38b95be802f7 (patch)
treed0ded85f3cf08f3aabcac68caee4842f3e94da4a
parentd0611c10a08a4e4f78885e76856155a1f05e6720 (diff)
downloadComputeLibrary-93e743fbe7d52f4c41fcd90762fc38b95be802f7.tar.gz
Optimize CpuSoftmaxKernel for axis != 0 and neon kernels
Resolves: COMPMID-6501 Signed-off-by: Omar Al Khatib <omar.alkhatib@arm.com> Change-Id: I0abd3cbb5f861301f407c443988fb7efaa205b5d Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11056 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--docs/user_guide/release_version_and_change_log.dox2
-rw-r--r--src/cpu/kernels/CpuSoftmaxKernel.cpp62
-rw-r--r--src/cpu/kernels/CpuSoftmaxKernel.h12
-rw-r--r--src/cpu/kernels/softmax/generic/neon/fp16.cpp22
-rw-r--r--src/cpu/kernels/softmax/generic/neon/fp32.cpp22
-rw-r--r--src/cpu/kernels/softmax/generic/neon/impl.cpp353
-rw-r--r--src/cpu/kernels/softmax/generic/neon/impl.h197
-rw-r--r--src/cpu/kernels/softmax/generic/neon/qasymm8.cpp22
-rw-r--r--src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp17
-rw-r--r--src/cpu/kernels/softmax/list.h4
-rw-r--r--src/cpu/operators/CpuSoftmax.cpp90
-rw-r--r--src/cpu/operators/CpuSoftmax.h9
12 files changed, 666 insertions, 146 deletions
diff --git a/docs/user_guide/release_version_and_change_log.dox b/docs/user_guide/release_version_and_change_log.dox
index bc7d2cb126..2d46737e96 100644
--- a/docs/user_guide/release_version_and_change_log.dox
+++ b/docs/user_guide/release_version_and_change_log.dox
@@ -44,6 +44,8 @@ If there is more than one release in a month then an extra sequential number is
v24.04 Public major release
- Optimize start-up time of @ref NEConvolutionLayer for some input configurations where GeMM is selected as the convolution algorithm
- Optimize @ref NEConvolutionLayer for input tensor size > 1e7 bytes and weight tensor height > 7
+ - Performance optimizations:
+ - Optimize @ref NESoftmaxLayer for axis != 0 by natively supporting higher axes up to axis 3.
v24.02.1 Public patch release
- Fix performance regression in fixed-format kernels
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.cpp b/src/cpu/kernels/CpuSoftmaxKernel.cpp
index 68bc397acf..54ff858eeb 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.cpp
+++ b/src/cpu/kernels/CpuSoftmaxKernel.cpp
@@ -81,7 +81,7 @@ static const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> available_ker
};
Status validate_arguments_softmax(
- const ITensorInfo &src, const ITensorInfo &dst, float beta, const ITensorInfo &tmp, bool is_log)
+ const ITensorInfo &src, const ITensorInfo &dst, float beta, int axis, const ITensorInfo &tmp, bool is_log)
{
ARM_COMPUTE_UNUSED(beta);
// Check input
@@ -89,6 +89,8 @@ Status validate_arguments_softmax(
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(axis < 0 || axis > 3);
+
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src.data_type());
// Check output if configured
@@ -124,10 +126,13 @@ const std::vector<typename CpuSoftmaxKernel::SoftmaxKernel> &CpuSoftmaxKernel::g
return available_kernels;
}
-void CpuSoftmaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, ITensorInfo *tmp)
+void CpuSoftmaxKernel::configure(
+ const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp)
{
+ _axis = axis;
+
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
// Configure kernel window
const bool is_quantized_asymmetric = is_data_type_quantized_asymmetric(src->data_type());
@@ -154,25 +159,40 @@ void CpuSoftmaxKernel::configure(const ITensorInfo *src, ITensorInfo *dst, float
_run_method = uk->ukernel;
_name = kernel_name.append("/").append(uk->name);
- Window win = calculate_max_window(*dst, Steps());
+ Window win;
+
+ int vec_size = 16 / dst->element_size();
- /// TODO: Check dimensions > 0 for holes only. For this, we need
- /// a utility function checking if there are holes after some dimension.
- if (!has_holes(*dst, dst->num_dimensions() - 1))
+ if (_axis == 0)
+ {
+ win = calculate_max_window(*dst, Steps());
+
+ /// TODO:Check dimensions > 0 for holes only. For this, we need
+ /// a utility function checking if there are holes after some dimension.
+ if (!has_holes(*dst, dst->num_dimensions() - 1))
+ {
+ win = win.collapse(win, Window::DimY);
+ }
+ }
+ else if (_axis > 0 && _axis <= 3)
{
- win = win.collapse(win, Window::DimY);
+ win = calculate_max_window(*dst, Steps(vec_size));
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR("Invalid axis");
}
- win.set(Window::DimX, Window::Dimension(0, 1, 1)); // First dimension is the reduction axis
+ win.set(_axis, Window::Dimension(0, 1, 1));
ICpuKernel<CpuSoftmaxKernel>::configure(win);
}
Status CpuSoftmaxKernel::validate(
- const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp)
+ const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, tmp);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, *tmp, is_log));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_softmax(*src, *dst, beta, axis, *tmp, is_log));
return Status{};
}
@@ -188,19 +208,25 @@ void CpuSoftmaxKernel::run_op(ITensorPack &tensors, const Window &window, const
if (is_data_type_quantized_asymmetric(src->info()->data_type()))
{
- auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
-
- const unsigned int num_elems_processed_per_iteration = src->info()->valid_region().shape.x();
+ auto tmp = tensors.get_tensor(TensorType::ACL_DST_1);
+ unsigned int num_elems_processed_per_iteration;
+ if (_axis == 0)
+ {
+ num_elems_processed_per_iteration = src->info()->valid_region().shape[_axis];
+ }
+ else
+ {
+ //16 QASYMM8/QASYMM8_SIGNED elements can fit into the 16-byte vectors.
+ num_elems_processed_per_iteration = 16;
+ }
const unsigned int tmp_size_for_thread = tmp->info()->element_size() * num_elems_processed_per_iteration;
- ARM_COMPUTE_ERROR_ON(tmp->info()->total_size() < (info.num_threads * tmp_size_for_thread));
-
void *tmp_for_thread = tmp->buffer() + (info.thread_id * tmp_size_for_thread);
- _run_method(src, tmp_for_thread, dst, _beta, window);
+ _run_method(src, tmp_for_thread, dst, _beta, _axis, window);
}
else
{
- _run_method(src, nullptr, dst, _beta, window);
+ _run_method(src, nullptr, dst, _beta, _axis, window);
}
}
diff --git a/src/cpu/kernels/CpuSoftmaxKernel.h b/src/cpu/kernels/CpuSoftmaxKernel.h
index 3db1f3d0ef..043ad975d5 100644
--- a/src/cpu/kernels/CpuSoftmaxKernel.h
+++ b/src/cpu/kernels/CpuSoftmaxKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,7 +38,7 @@ class CpuSoftmaxKernel : public ICpuKernel<CpuSoftmaxKernel>
{
private:
using SoftmaxKernelPtr =
- std::add_pointer<void(const ITensor *, void *const, ITensor *, float, const Window &)>::type;
+ std::add_pointer<void(const ITensor *, void *const, ITensor *, float, int, const Window &)>::type;
public:
CpuSoftmaxKernel() = default;
@@ -49,11 +49,12 @@ public:
* @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[out] dst Destination tensor info. Data types supported: same as @p input.
* @param[in] beta A scaling factor for the exponent.
- * @param[in] is_log True if the operation is log-softmax
+ * @param[in] is_log True if the operation is log-softmax.
+ * @param[in] axis The axis along which to perform the softmax operation.
*
* @param tmp Auxiliary tensor info. Must be type F32 and same shape as the input.
*/
- void configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, ITensorInfo *tmp);
+ void configure(const ITensorInfo *src, ITensorInfo *dst, float beta, bool is_log, int axis, ITensorInfo *tmp);
/** Static function to check if given info will lead to a valid configuration
*
* Similar to CpuSoftmaxKernel::configure()
@@ -61,7 +62,7 @@ public:
* @return a status
*/
static Status
- validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, bool is_log, const ITensorInfo *tmp);
+ validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, int axis, bool is_log, const ITensorInfo *tmp);
// Inherited methods overridden:
void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
@@ -80,6 +81,7 @@ private:
float _beta{1.0f};
SoftmaxKernelPtr _run_method{nullptr};
std::string _name{};
+ int _axis{};
};
} // namespace kernels
} // namespace cpu
diff --git a/src/cpu/kernels/softmax/generic/neon/fp16.cpp b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
index db8f881712..da62d2d614 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp16.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -33,15 +33,23 @@ namespace cpu
{
template <bool IS_LOG>
-void neon_fp16_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_fp16_softmax(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
{
- return neon_softmax_float<float16_t, IS_LOG>(in, tmp, out, beta, window);
+ if (axis == 0)
+ {
+ return neon_softmax_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
+ else
+ {
+ return neon_softmax_non_x_float<float16_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
}
-template void
-neon_fp16_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_fp16_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_fp16_softmax<true>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_fp16_softmax<false>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/fp32.cpp b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
index c281d1bf31..0701620636 100644
--- a/src/cpu/kernels/softmax/generic/neon/fp32.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/fp32.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,15 +31,23 @@ namespace cpu
{
template <bool IS_LOG>
-void neon_fp32_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_fp32_softmax(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
{
- return neon_softmax_float<float, IS_LOG>(in, tmp, out, beta, window);
+ if (axis == 0)
+ {
+ return neon_softmax_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
+ else
+ {
+ return neon_softmax_non_x_float<float, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
}
-template void
-neon_fp32_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_fp32_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_fp32_softmax<true>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_fp32_softmax<false>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.cpp b/src/cpu/kernels/softmax/generic/neon/impl.cpp
index 487f6ae051..31baf8a9df 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/impl.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,8 +30,11 @@ namespace arm_compute
namespace cpu
{
template <typename T, bool IS_LOG>
-void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+void neon_softmax_x_quantized(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
{
+ ARM_COMPUTE_UNUSED(axis);
+
static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
"quantized type should be either qasymm8_t or qasymm8_signed_t.");
@@ -248,16 +251,346 @@ void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, fl
in_it, out_it);
}
-template void neon_softmax_quantized<qasymm8_signed_t, true>(
- const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+template <typename T, bool IS_LOG>
+void neon_softmax_non_x_quantized(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
+{
+ static_assert(std::is_same<T, qasymm8_t>::value || std::is_same<T, qasymm8_signed_t>::value,
+ "quantized type should be either qasymm8_t or qasymm8_signed_t.");
+
+ const float scale_beta = -beta * in->info()->quantization_info().uniform().scale;
+ const float32x4_t scale_beta_vec = vdupq_n_f32(scale_beta);
+
+ Iterator in_it(in, window);
+ Iterator out_it(out, window);
+
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+ constexpr int vec_size = 16;
+ const ITensorInfo *in_info = in->info();
+ const ITensorInfo *out_info = out->info();
+ const int x_width = in_info->valid_region().shape.x();
+ const int in_axis_stride = in_info->strides_in_bytes()[axis];
+ const int out_axis_stride = out_info->strides_in_bytes()[axis];
+ const int tmp_axis_stride = in_axis_stride;
+ const int axis_width = in_info->dimension(axis);
+ const int end_actual = std::min(window[0].end(), x_width);
+
+ execute_window_loop(
+ window,
+ [&](const Coordinates &winCoords)
+ {
+ const bool vector_exceeds_bounds = ((winCoords[0] + vec_size) > end_actual);
+
+ int num_remaining = (end_actual - winCoords[0]);
+ int num_remaining_full = num_remaining / 4;
+ int num_remaining_partial = num_remaining % 4;
+
+ /* Get pointers */
+ const uint8_t *in_ptr = in_it.ptr();
+ uint8_t *out_ptr = out_it.ptr();
+ uint8_t *tmp_ptr = reinterpret_cast<uint8_t *>(tmp);
+
+ auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+
+ /* Compute Max */
+ {
+ if (!vector_exceeds_bounds)
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const auto current_value =
+ wrapper::vloadq((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+ vec_max = wrapper::vmax(vec_max, current_value);
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const T *const base_ptr_in = ((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+ int j = 0;
+ for (; j < num_remaining; ++j)
+ {
+ const T current_value = *(base_ptr_in + j);
+ vec_max[j] = std::max(vec_max[j], current_value);
+ }
+ }
+ }
+ } // Compute Max
+
+ float32x4x4_t vec_sum_transformed = {
+ vdupq_n_f32(0.f),
+ vdupq_n_f32(0.f),
+ vdupq_n_f32(0.f),
+ vdupq_n_f32(0.f),
+ };
+
+ /* Compute exponentials and sum */
+ {
+ /* Init sum to zero */
+ float32x4x4_t vec_sum = vec_sum_transformed;
+
+ auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+ float32x4x4_t vec_elements_flt;
+
+ if (!vector_exceeds_bounds)
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ vec_elements = wrapper::vloadq((i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr));
+ vec_elements = wrapper::vqsub(vec_max, vec_elements);
+ vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
+
+ if (IS_LOG)
+ {
+ vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
+ vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
+ vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
+ vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
+ vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
+ vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
+ vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
+ vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
+ }
+ else
+ {
+ vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
+ vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
+ vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
+ vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
+ vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
+ vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
+ vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
+ vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
+ }
+ vst4q_f32((i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr), vec_elements_flt);
+ }
+
+ auto vec_256 = wrapper::vdup_n(static_cast<float32_t>(256.f), ExactTagType{});
+ if (!IS_LOG)
+ {
+ vec_sum_transformed.val[0] = wrapper::vdiv(vec_256, vec_sum.val[0]);
+ vec_sum_transformed.val[1] = wrapper::vdiv(vec_256, vec_sum.val[1]);
+ vec_sum_transformed.val[2] = wrapper::vdiv(vec_256, vec_sum.val[2]);
+ vec_sum_transformed.val[3] = wrapper::vdiv(vec_256, vec_sum.val[3]);
+ }
+ else
+ {
+ vec_sum_transformed.val[0] = wrapper::vlog(vec_sum.val[0]);
+ vec_sum_transformed.val[1] = wrapper::vlog(vec_sum.val[1]);
+ vec_sum_transformed.val[2] = wrapper::vlog(vec_sum.val[2]);
+ vec_sum_transformed.val[3] = wrapper::vlog(vec_sum.val[3]);
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const T *const base_ptr_in = (i * in_axis_stride) + reinterpret_cast<const T *>(in_ptr);
+ auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+ //vec_els is functionally redundant but is needed as a workaround for a toolchain bug.
+ std::vector<T> vec_els(16);
+
+ for (int k = 0; k < num_remaining_full; ++k)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ vec_els[k * 4 + j] = *(base_ptr_in + (4 * k + j));
+ }
+ }
+ for (int j = 0; j < num_remaining_partial; ++j)
+ {
+ vec_els[num_remaining_full * 4 + j] = *(base_ptr_in + (4 * num_remaining_full + j));
+ }
+ for (int q = 0; q < 16; q++)
+ {
+ vec_elements[q] = vec_els[q];
+ }
+ vec_elements = wrapper::vqsub(vec_max, vec_elements);
+ float32x4x4_t vec_elements_flt = convert_int_to_float<float32x4x4_t>(vec_elements);
+
+ if (IS_LOG)
+ {
+ vec_elements_flt.val[0] = vmulq_f32(vec_elements_flt.val[0], scale_beta_vec);
+ vec_elements_flt.val[1] = vmulq_f32(vec_elements_flt.val[1], scale_beta_vec);
+ vec_elements_flt.val[2] = vmulq_f32(vec_elements_flt.val[2], scale_beta_vec);
+ vec_elements_flt.val[3] = vmulq_f32(vec_elements_flt.val[3], scale_beta_vec);
+ vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vexpq_f32(vec_elements_flt.val[0]));
+ vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vexpq_f32(vec_elements_flt.val[1]));
+ vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vexpq_f32(vec_elements_flt.val[2]));
+ vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vexpq_f32(vec_elements_flt.val[3]));
+ }
+ else
+ {
+ vec_elements_flt.val[0] = vexpq_f32(vmulq_f32(vec_elements_flt.val[0], scale_beta_vec));
+ vec_elements_flt.val[1] = vexpq_f32(vmulq_f32(vec_elements_flt.val[1], scale_beta_vec));
+ vec_elements_flt.val[2] = vexpq_f32(vmulq_f32(vec_elements_flt.val[2], scale_beta_vec));
+ vec_elements_flt.val[3] = vexpq_f32(vmulq_f32(vec_elements_flt.val[3], scale_beta_vec));
+ vec_sum.val[0] = vaddq_f32(vec_sum.val[0], vec_elements_flt.val[0]);
+ vec_sum.val[1] = vaddq_f32(vec_sum.val[1], vec_elements_flt.val[1]);
+ vec_sum.val[2] = vaddq_f32(vec_sum.val[2], vec_elements_flt.val[2]);
+ vec_sum.val[3] = vaddq_f32(vec_sum.val[3], vec_elements_flt.val[3]);
+ }
+
+ float *const base_ptr_tmp = (i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr);
+ for (int k = 0; k < num_remaining_full; ++k)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ *(base_ptr_tmp + (4 * k + j)) = vec_elements_flt.val[k][j];
+ }
+ }
+
+ for (int j = 0; j < num_remaining_partial; ++j)
+ {
+ *(base_ptr_tmp + (4 * num_remaining_full + j)) =
+ vec_elements_flt.val[num_remaining_full][j];
+ }
+ }
+
+ auto vec_256 = wrapper::vdup_n(static_cast<float32_t>(256), ExactTagType{});
+ if (!IS_LOG)
+ {
+ vec_sum_transformed.val[0] = wrapper::vdiv(vec_256, vec_sum.val[0]);
+ vec_sum_transformed.val[1] = wrapper::vdiv(vec_256, vec_sum.val[1]);
+ vec_sum_transformed.val[2] = wrapper::vdiv(vec_256, vec_sum.val[2]);
+ vec_sum_transformed.val[3] = wrapper::vdiv(vec_256, vec_sum.val[3]);
+ }
+ else
+ {
+ vec_sum_transformed.val[0] = wrapper::vlog(vec_sum.val[0]);
+ vec_sum_transformed.val[1] = wrapper::vlog(vec_sum.val[1]);
+ vec_sum_transformed.val[2] = wrapper::vlog(vec_sum.val[2]);
+ vec_sum_transformed.val[3] = wrapper::vlog(vec_sum.val[3]);
+ }
+ }
+ } // Compute exponentials and sum
+
+ /* Normalize exponentials */
+ {
+ constexpr bool is_qasymm8_signed = std::is_same<T, qasymm8_signed_t>::value;
+ if (!vector_exceeds_bounds)
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ using int_vec_type = wrapper::traits::neon_vector_t<T, 16>;
+ float32x4x4_t vec_in = vld4q_f32((i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr));
+
+ int_vec_type normalized_value{};
+
+ if (IS_LOG)
+ {
+ const float32x4x4_t sub = {
+ vsubq_f32(vec_in.val[0], vec_sum_transformed.val[0]),
+ vsubq_f32(vec_in.val[1], vec_sum_transformed.val[1]),
+ vsubq_f32(vec_in.val[2], vec_sum_transformed.val[2]),
+ vsubq_f32(vec_in.val[3], vec_sum_transformed.val[3]),
+ };
+ normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(sub);
+ }
+ else
+ {
+ float32x4x4_t mul = {
+ vmulq_f32(vec_in.val[0], vec_sum_transformed.val[0]),
+ vmulq_f32(vec_in.val[1], vec_sum_transformed.val[1]),
+ vmulq_f32(vec_in.val[2], vec_sum_transformed.val[2]),
+ vmulq_f32(vec_in.val[3], vec_sum_transformed.val[3]),
+ };
+
+ if (is_qasymm8_signed)
+ {
+ const auto offset_vec = wrapper::vdup_n(128.f, wrapper::traits::vector_128_tag{});
+ mul.val[0] = wrapper::vsub(mul.val[0], offset_vec);
+ mul.val[1] = wrapper::vsub(mul.val[1], offset_vec);
+ mul.val[2] = wrapper::vsub(mul.val[2], offset_vec);
+ mul.val[3] = wrapper::vsub(mul.val[3], offset_vec);
+ }
+
+ normalized_value = convert_float_to_int<float32x4x4_t, int_vec_type>(mul);
+ }
+ wrapper::vstore((i * out_axis_stride) + reinterpret_cast<T *>(out_ptr), normalized_value);
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ T *const base_ptr_out = (i * out_axis_stride) + reinterpret_cast<T *>(out_ptr);
+ float *const base_ptr_tmp = (i * tmp_axis_stride) + reinterpret_cast<float *>(tmp_ptr);
+ if (IS_LOG)
+ {
+ for (int k = 0; k < num_remaining_full; ++k)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ *(base_ptr_out + (4 * k + j)) = utils::cast::saturate_cast<T>(
+ (*(base_ptr_tmp + (4 * k + j)) - vec_sum_transformed.val[k][j]));
+ }
+ }
+ for (int j = 0; j < num_remaining_partial; ++j)
+ {
+ *(base_ptr_out + (4 * num_remaining_full + j)) =
+ utils::cast::saturate_cast<T>(*(base_ptr_tmp + (4 * num_remaining_full + j)) -
+ vec_sum_transformed.val[num_remaining_full][j]);
+ }
+ }
+ else
+ {
+ for (int k = 0; k < num_remaining_full; ++k)
+ {
+ for (int j = 0; j < 4; ++j)
+ {
+ *(base_ptr_out + (4 * k + j)) = utils::cast::saturate_cast<T>(
+ *(base_ptr_tmp + (4 * k + j)) * vec_sum_transformed.val[k][j] -
+ (is_qasymm8_signed ? 128.f : 0));
+ }
+ }
+ for (int j = 0; j < num_remaining_partial; ++j)
+ {
+ *(base_ptr_out + (4 * num_remaining_full + j)) =
+ utils::cast::saturate_cast<T>(*(base_ptr_tmp + (4 * num_remaining_full + j)) *
+ vec_sum_transformed.val[num_remaining_full][j] -
+ (is_qasymm8_signed ? 128.f : 0));
+ }
+ }
+ }
+ }
+ } // Normalize exponentials
+ },
+ in_it, out_it);
+}
+
+template void neon_softmax_x_quantized<qasymm8_signed_t, true>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_signed_t, false>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_t, true>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_x_quantized<qasymm8_t, false>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
+
+template void neon_softmax_non_x_quantized<qasymm8_signed_t, true>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
-template void neon_softmax_quantized<qasymm8_signed_t, false>(
- const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+template void neon_softmax_non_x_quantized<qasymm8_signed_t, false>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
-template void neon_softmax_quantized<qasymm8_t, true>(
- const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+template void neon_softmax_non_x_quantized<qasymm8_t, true>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
-template void neon_softmax_quantized<qasymm8_t, false>(
- const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+template void neon_softmax_non_x_quantized<qasymm8_t, false>(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/impl.h b/src/cpu/kernels/softmax/generic/neon/impl.h
index 60380cd233..e417271d0e 100644
--- a/src/cpu/kernels/softmax/generic/neon/impl.h
+++ b/src/cpu/kernels/softmax/generic/neon/impl.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -62,8 +62,9 @@ inline float wrapper_vaddv(const float32x4_t &a, int sum_stages)
// The template implementation for float data types is stored in the header file because
// we need all fp16 instantiated code to live in fp16.cpp files.
template <typename T, bool IS_LOG>
-void neon_softmax_float(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window)
+void neon_softmax_x_float(const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
{
+ ARM_COMPUTE_UNUSED(axis);
ARM_COMPUTE_UNUSED(tmp);
const int input_width = in->info()->valid_region().shape.x();
@@ -228,9 +229,199 @@ void neon_softmax_float(const ITensor *in, void *const tmp, ITensor *out, float
},
in_it, out_it);
}
+template <typename T, bool IS_LOG>
+void neon_softmax_non_x_float(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window)
+{
+ ARM_COMPUTE_UNUSED(tmp);
+
+ Iterator in_it(in, window);
+ Iterator out_it(out, window);
+
+ /** SIMD vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
+
+ const auto beta_vec = wrapper::vdup_n(static_cast<T>(beta), ExactTagType{});
+ constexpr int vec_size = 16 / sizeof(T);
+ const ITensorInfo *in_info = in->info();
+ const ITensorInfo *out_info = out->info();
+ const int x_width = in_info->valid_region().shape.x();
+ const unsigned int in_axis_stride = in_info->strides_in_bytes()[axis];
+ const unsigned int out_axis_stride = out_info->strides_in_bytes()[axis];
+ const int axis_width = in_info->dimension(axis);
+
+ execute_window_loop(
+ window,
+ [&](const Coordinates &winCoords)
+ {
+ const bool vector_exceeds_bounds = (winCoords[0] + vec_size) > x_width;
+
+ /* Get pointers */
+ const uint8_t *in_ptr = in_it.ptr();
+ uint8_t *out_ptr = out_it.ptr();
+
+ // Init max value
+ auto vec_max = wrapper::vdup_n(support::cpp11::lowest<T>(), ExactTagType{});
+
+ /* Compute Max */
+ {
+ if (!vector_exceeds_bounds)
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const auto current_value =
+ wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
+ vec_max = wrapper::vmax(vec_max, current_value);
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
+ int j = 0;
+ for (; j < (x_width - winCoords[0]); ++j)
+ {
+ const auto current_value = *(base_ptr_in + j);
+ vec_max[j] = std::max(vec_max[j], current_value);
+ }
+ }
+ }
+ } // compute max
+
+ auto vec_sum_transformed = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+ auto vec_elements = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+ /* Init sum to zero */
+ auto vec_sum = wrapper::vdup_n(static_cast<T>(0), ExactTagType{});
+
+ /* Compute exponentials and sum */
+ {
+ if (!vector_exceeds_bounds)
+ {
+ const auto vec_one = wrapper::vdup_n(static_cast<T>(1), ExactTagType{});
+ /* Loop over row and compute exponentials and sum */
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ vec_elements = wrapper::vloadq(reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr));
+ vec_elements = wrapper::vsub(vec_elements, vec_max);
+ if (IS_LOG)
+ {
+ vec_elements = wrapper::vmul(vec_elements, beta_vec);
+ vec_sum = wrapper::vadd(vec_sum, wrapper::vexpq(vec_elements));
+ }
+ else
+ {
+ vec_elements = wrapper::vexpq(wrapper::vmul(vec_elements, beta_vec));
+ vec_sum = wrapper::vadd(vec_sum, vec_elements);
+ }
+
+ wrapper::vstore(reinterpret_cast<T *>((i * out_axis_stride) + out_ptr), vec_elements);
+ }
+
+ if (!IS_LOG)
+ {
+ vec_sum_transformed = wrapper::vdiv(vec_one, vec_sum);
+ }
+ else
+ {
+ vec_sum_transformed = wrapper::vlog(vec_sum);
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ const T *const base_ptr_in = reinterpret_cast<const T *>((i * in_axis_stride) + in_ptr);
+ T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+ int j = 0;
+ for (; j < (x_width - winCoords[0]); ++j)
+ {
+ vec_elements[j] = *(base_ptr_in + j);
+ vec_elements[j] -= vec_max[j];
+ if (IS_LOG)
+ {
+ vec_elements[j] *= beta;
+ vec_sum[j] += std::exp(vec_elements[j]);
+ }
+ else
+ {
+ vec_elements[j] = std::exp(vec_elements[j] * beta);
+ vec_sum[j] += vec_elements[j];
+ }
+ *(base_ptr_out + j) = vec_elements[j];
+ }
+ }
+ int j = 0;
+ for (; j < (x_width - winCoords[0]); ++j)
+ {
+ if (!IS_LOG)
+ {
+ vec_sum_transformed[j] = 1 / vec_sum[j];
+ }
+ else
+ {
+ vec_sum_transformed[j] = std::log(vec_sum[j]);
+ }
+ }
+ }
+ } // Compute exponentials and sum
+
+ /* Normalize exponentials */
+ {
+ if (!vector_exceeds_bounds)
+ {
+ /* Loop over row and compute softmax */
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+ auto vec_in = wrapper::vloadq(base_ptr_out);
+ if (IS_LOG)
+ {
+ wrapper::vstore(base_ptr_out, wrapper::vsub(vec_in, vec_sum_transformed));
+ }
+ else
+ {
+ wrapper::vstore(base_ptr_out, wrapper::vmul(vec_in, vec_sum_transformed));
+ }
+ }
+ }
+ else
+ {
+ int i = 0;
+ for (; i < axis_width; ++i)
+ {
+ T *const base_ptr_out = reinterpret_cast<T *>((i * out_axis_stride) + out_ptr);
+ int j = 0;
+ for (; j < (x_width - winCoords[0]); ++j)
+ {
+ if (IS_LOG)
+ {
+ *(base_ptr_out + j) -= vec_sum_transformed[j];
+ }
+ else
+ {
+ *(base_ptr_out + j) *= vec_sum_transformed[j];
+ }
+ }
+ }
+ }
+ } // Normalize exponentials
+ },
+ in_it, out_it);
+}
+template <typename T, bool IS_LOG>
+void neon_softmax_x_quantized(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
template <typename T, bool IS_LOG>
-void neon_softmax_quantized(const ITensor *in, void *const tmp, ITensor *out, float beta, const Window &window);
+void neon_softmax_non_x_quantized(
+ const ITensor *in, void *const tmp, ITensor *out, float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
index 9589ebcd7c..d39240bb38 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,15 +30,23 @@ namespace arm_compute
namespace cpu
{
template <bool IS_LOG>
-void neon_qasymm8_softmax(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+void neon_qasymm8_softmax(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
{
- return neon_softmax_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, window);
+ if (axis == 0)
+ {
+ return neon_softmax_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
+ else
+ {
+ return neon_softmax_non_x_quantized<qasymm8_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
}
-template void
-neon_qasymm8_softmax<true>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
-template void
-neon_qasymm8_softmax<false>(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+template void neon_qasymm8_softmax<true>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
+template void neon_qasymm8_softmax<false>(
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
index 0bf6b2859a..26fd5dbfa0 100644
--- a/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
+++ b/src/cpu/kernels/softmax/generic/neon/qasymm8_signed.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,15 +31,22 @@ namespace cpu
{
template <bool IS_LOG>
void neon_qasymm8_signed_softmax(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
{
- return neon_softmax_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, window);
+ if (axis == 0)
+ {
+ return neon_softmax_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
+ else
+ {
+ return neon_softmax_non_x_quantized<qasymm8_signed_t, IS_LOG>(in, tmp, out, beta, axis, window);
+ }
}
template void neon_qasymm8_signed_softmax<true>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
template void neon_qasymm8_signed_softmax<false>(
- const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window);
+ const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window);
} // namespace cpu
} // namespace arm_compute
diff --git a/src/cpu/kernels/softmax/list.h b/src/cpu/kernels/softmax/list.h
index c143f6659d..f9295ebbcc 100644
--- a/src/cpu/kernels/softmax/list.h
+++ b/src/cpu/kernels/softmax/list.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,7 +30,7 @@ namespace cpu
{
#define DECLARE_SOFTMAX_KERNEL(func_name) \
template <bool IS_LOG> \
- void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, const Window &window)
+ void func_name(const ITensor *in, void *const tmp, ITensor *out, const float beta, int axis, const Window &window)
DECLARE_SOFTMAX_KERNEL(neon_fp32_softmax);
DECLARE_SOFTMAX_KERNEL(neon_fp16_softmax);
diff --git a/src/cpu/operators/CpuSoftmax.cpp b/src/cpu/operators/CpuSoftmax.cpp
index ae14381ad9..fecee7d765 100644
--- a/src/cpu/operators/CpuSoftmax.cpp
+++ b/src/cpu/operators/CpuSoftmax.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,15 +41,7 @@ namespace arm_compute
{
namespace cpu
{
-CpuSoftmaxGeneric::CpuSoftmaxGeneric()
- : _permute_input(),
- _permute_output(),
- _softmax_kernel(),
- _tmp(),
- _input_permuted(),
- _output_permuted(),
- _needs_permute(false),
- _aux_mem(InternalTensorIdx::COUNT)
+CpuSoftmaxGeneric::CpuSoftmaxGeneric() : _softmax_kernel(), _tmp(), _aux_mem(InternalTensorIdx::COUNT)
{
}
@@ -63,17 +55,9 @@ void CpuSoftmaxGeneric::configure(const ITensorInfo *src, ITensorInfo *dst, floa
const unsigned int actual_axis =
static_cast<unsigned int>(wrap_around(axis, static_cast<int32_t>(src->num_dimensions())));
- _needs_permute = actual_axis > 0;
+ _axis = actual_axis;
- if (_needs_permute)
- {
- _permute_input.configure(src, &_input_permuted,
- softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
- }
-
- // We want to deal with a 2D input. Either it is the permuted version of the original input (4D case)
- // or it is the original input case (2D case)
- const ITensorInfo *tmp_input = (_needs_permute ? &_input_permuted : src);
+ const ITensorInfo *tmp_input = src;
TensorInfo tensor_info_tmp;
if (is_data_type_quantized_asymmetric(src->data_type()))
@@ -88,20 +72,10 @@ void CpuSoftmaxGeneric::configure(const ITensorInfo *src, ITensorInfo *dst, floa
// Configure kernels
auto sm = std::make_unique<kernels::CpuSoftmaxKernel>();
- if (_needs_permute)
- {
- // The normalization kernel stores the result in a permuted output tensor
- sm->configure(tmp_input, &_output_permuted, beta, is_log, &_tmp);
- // Re-permute the permuted output into the requested (4D) output
- _permute_output.configure(&_output_permuted, dst,
- softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
- }
- else
- {
- // Softmax 2D case
- sm->configure(tmp_input, dst, beta, is_log, &_tmp);
- }
+ // Softmax 2D case
+ sm->configure(tmp_input, dst, beta, is_log, actual_axis, &_tmp);
+
_softmax_kernel = std::move(sm);
if (_tmp.total_size() > 0)
@@ -109,11 +83,6 @@ void CpuSoftmaxGeneric::configure(const ITensorInfo *src, ITensorInfo *dst, floa
_aux_mem[InternalTensorIdx::TMP] =
MemoryInfo(offset_int_vec(InternalTensorIdx::TMP), MemoryLifetime::Temporary, _tmp.total_size());
}
-
- _aux_mem[InternalTensorIdx::PERMUTED_SRC] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_SRC),
- MemoryLifetime::Temporary, _input_permuted.total_size());
- _aux_mem[InternalTensorIdx::PERMUTED_DST] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_DST),
- MemoryLifetime::Temporary, _output_permuted.total_size());
}
Status
@@ -133,25 +102,11 @@ CpuSoftmaxGeneric::validate(const ITensorInfo *src, const ITensorInfo *dst, floa
{
tensor_info_tmp = src->clone()->set_data_type(DataType::F32).set_is_resizable(true);
}
-
const unsigned int actual_axis =
static_cast<unsigned int>(wrap_around(axis, static_cast<int32_t>(src->num_dimensions())));
- const bool needs_permute = actual_axis > 0;
-
- if (needs_permute)
- {
- const PermutationVector permutation_vector =
- softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis);
- const TensorShape permuted_shape =
- misc::shape_calculator::compute_permutation_output_shape(*src, permutation_vector);
- TensorInfo input_permuted(src->clone()->set_tensor_shape(permuted_shape));
- ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(src, &input_permuted, permutation_vector));
- TensorInfo output_permuted(dst->clone()->set_tensor_shape(permuted_shape));
- ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(&output_permuted, dst, permutation_vector));
- }
-
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuSoftmaxKernel::validate(src, dst, beta, is_log, &tensor_info_tmp));
+ ARM_COMPUTE_RETURN_ON_ERROR(
+ kernels::CpuSoftmaxKernel::validate(src, dst, beta, actual_axis, is_log, &tensor_info_tmp));
return Status{};
}
@@ -165,34 +120,17 @@ void CpuSoftmaxGeneric::run(ITensorPack &tensors)
CpuAuxTensorHandler tmp(offset_int_vec(InternalTensorIdx::TMP), _tmp, tensors, true);
- CpuAuxTensorHandler input_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), _input_permuted, tensors, true);
- CpuAuxTensorHandler output_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_DST), _output_permuted, tensors,
- true);
-
ITensorPack softmax_pack;
- if (_needs_permute)
- {
- ITensorPack permute_in_pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, input_permuted.get()}};
- _permute_input.run(permute_in_pack);
+ softmax_pack = {{TensorType::ACL_SRC_0, src}, {TensorType::ACL_DST_0, dst}, {TensorType::ACL_DST_1, tmp.get()}};
- softmax_pack = {{TensorType::ACL_SRC_0, input_permuted.get()},
- {TensorType::ACL_DST_0, output_permuted.get()},
- {TensorType::ACL_DST_1, tmp.get()}};
- }
- else
+ if (_axis == 0)
{
- softmax_pack = {{TensorType::ACL_SRC_0, src}, {TensorType::ACL_DST_0, dst}, {TensorType::ACL_DST_1, tmp.get()}};
+ NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimY, _softmax_kernel->window(), softmax_pack);
}
-
- NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimY, _softmax_kernel->window(), softmax_pack);
-
- if (_needs_permute)
+ else
{
- ITensorPack permute_out_pack;
- permute_out_pack.add_tensor(TensorType::ACL_SRC, output_permuted.get());
- permute_out_pack.add_tensor(TensorType::ACL_DST, dst);
- _permute_output.run(permute_out_pack);
+ NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimX, _softmax_kernel->window(), softmax_pack);
}
}
diff --git a/src/cpu/operators/CpuSoftmax.h b/src/cpu/operators/CpuSoftmax.h
index 47020e9b7c..6ba3476eff 100644
--- a/src/cpu/operators/CpuSoftmax.h
+++ b/src/cpu/operators/CpuSoftmax.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -89,16 +89,13 @@ private:
COUNT
};
- CpuPermute _permute_input;
- CpuPermute _permute_output;
std::unique_ptr<ICPPKernel> _softmax_kernel;
TensorInfo _tmp;
- TensorInfo _input_permuted;
- TensorInfo _output_permuted;
- bool _needs_permute;
experimental::MemoryRequirements _aux_mem{};
+
+ unsigned int _axis = 0;
};
} // namespace cpu