aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuSoftmax.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuSoftmax.cpp')
-rw-r--r--src/cpu/operators/CpuSoftmax.cpp91
1 files changed, 30 insertions, 61 deletions
diff --git a/src/cpu/operators/CpuSoftmax.cpp b/src/cpu/operators/CpuSoftmax.cpp
index e55d7f903e..ae14381ad9 100644
--- a/src/cpu/operators/CpuSoftmax.cpp
+++ b/src/cpu/operators/CpuSoftmax.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,13 +41,10 @@ namespace arm_compute
{
namespace cpu
{
-template <bool IS_LOG>
-CpuSoftmaxGeneric<IS_LOG>::CpuSoftmaxGeneric()
+CpuSoftmaxGeneric::CpuSoftmaxGeneric()
: _permute_input(),
_permute_output(),
- _max_kernel(),
_softmax_kernel(),
- _max(),
_tmp(),
_input_permuted(),
_output_permuted(),
@@ -56,8 +53,7 @@ CpuSoftmaxGeneric<IS_LOG>::CpuSoftmaxGeneric()
{
}
-template <bool IS_LOG>
-void CpuSoftmaxGeneric<IS_LOG>::configure(const ITensorInfo *src, ITensorInfo *dst, float beta, int32_t axis)
+void CpuSoftmaxGeneric::configure(const ITensorInfo *src, ITensorInfo *dst, float beta, int32_t axis, bool is_log)
{
// Perform validation step
ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst);
@@ -79,29 +75,23 @@ void CpuSoftmaxGeneric<IS_LOG>::configure(const ITensorInfo *src, ITensorInfo *d
// or it is the original input case (2D case)
const ITensorInfo *tmp_input = (_needs_permute ? &_input_permuted : src);
- // Create intermediate tensors shapes
- TensorShape max_sum_shape = tmp_input->tensor_shape();
- max_sum_shape.set(0, 1);
- const TensorInfo input_info = tmp_input->clone()->reset_padding().set_is_resizable(true);
- DataType tmp_data_type =
- is_data_type_quantized_asymmetric(tmp_input->data_type()) ? DataType::F32 : tmp_input->data_type();
- TensorInfo tensor_info_tmp(input_info.clone()->set_data_type(tmp_data_type));
- TensorInfo max_info(tmp_input->clone()->set_tensor_shape(max_sum_shape));
+ TensorInfo tensor_info_tmp;
+ if (is_data_type_quantized_asymmetric(src->data_type()))
+ {
+ // Create intermediate tensors shapes
+ const TensorInfo input_info = tmp_input->clone()->reset_padding().set_is_resizable(true);
+ tensor_info_tmp = input_info.clone()->set_data_type(DataType::F32);
+ }
// Init intermediate tensors
- _max = TensorInfo(max_info);
_tmp = TensorInfo(tensor_info_tmp);
// Configure kernels
- auto mk = std::make_unique<kernels::CpuLogits1DMaxKernel>();
- mk->configure(tmp_input, &_max);
- _max_kernel = std::move(mk);
-
- auto sm = std::make_unique<kernels::CpuLogits1DSoftmaxKernel<IS_LOG>>();
+ auto sm = std::make_unique<kernels::CpuSoftmaxKernel>();
if (_needs_permute)
{
// The normalization kernel stores the result in a permuted output tensor
- sm->configure(tmp_input, &_max, &_output_permuted, beta, &_tmp);
+ sm->configure(tmp_input, &_output_permuted, beta, is_log, &_tmp);
// Re-permute the permuted output into the requested (4D) output
_permute_output.configure(&_output_permuted, dst,
@@ -110,14 +100,15 @@ void CpuSoftmaxGeneric<IS_LOG>::configure(const ITensorInfo *src, ITensorInfo *d
else
{
// Softmax 2D case
- sm->configure(tmp_input, &_max, dst, beta, &_tmp);
+ sm->configure(tmp_input, dst, beta, is_log, &_tmp);
}
_softmax_kernel = std::move(sm);
- _aux_mem[InternalTensorIdx::MAX] =
- MemoryInfo(offset_int_vec(InternalTensorIdx::MAX), MemoryLifetime::Temporary, _max.total_size());
- _aux_mem[InternalTensorIdx::TMP] =
- MemoryInfo(offset_int_vec(InternalTensorIdx::TMP), MemoryLifetime::Temporary, _tmp.total_size());
+ if (_tmp.total_size() > 0)
+ {
+ _aux_mem[InternalTensorIdx::TMP] =
+ MemoryInfo(offset_int_vec(InternalTensorIdx::TMP), MemoryLifetime::Temporary, _tmp.total_size());
+ }
_aux_mem[InternalTensorIdx::PERMUTED_SRC] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_SRC),
MemoryLifetime::Temporary, _input_permuted.total_size());
@@ -125,8 +116,8 @@ void CpuSoftmaxGeneric<IS_LOG>::configure(const ITensorInfo *src, ITensorInfo *d
MemoryLifetime::Temporary, _output_permuted.total_size());
}
-template <bool IS_LOG>
-Status CpuSoftmaxGeneric<IS_LOG>::validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, int32_t axis)
+Status
+CpuSoftmaxGeneric::validate(const ITensorInfo *src, const ITensorInfo *dst, float beta, int32_t axis, bool is_log)
{
// Perform validation step
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst);
@@ -136,17 +127,12 @@ Status CpuSoftmaxGeneric<IS_LOG>::validate(const ITensorInfo *src, const ITensor
static_cast<int32_t>(src->num_dimensions()) <= axis);
// Create intermediate tensor info
- DataType tmp_data_type = src->data_type();
- const TensorInfo tensor_info_tmp(src->clone()->set_data_type(tmp_data_type).set_is_resizable(true));
-
- TensorShape max_sum_shape = src->tensor_shape();
- max_sum_shape.set(0, 1);
- const TensorInfo tensor_info_max_sum(src->clone()
- ->set_tensor_shape(max_sum_shape)
- .set_data_type(tmp_data_type)
- .set_quantization_info(src->quantization_info())
- .set_is_resizable(true));
- const TensorInfo dont_care;
+ TensorInfo tensor_info_tmp;
+
+ if (is_data_type_quantized_asymmetric(src->data_type()))
+ {
+ tensor_info_tmp = src->clone()->set_data_type(DataType::F32).set_is_resizable(true);
+ }
const unsigned int actual_axis =
static_cast<unsigned int>(wrap_around(axis, static_cast<int32_t>(src->num_dimensions())));
@@ -165,15 +151,12 @@ Status CpuSoftmaxGeneric<IS_LOG>::validate(const ITensorInfo *src, const ITensor
ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(&output_permuted, dst, permutation_vector));
}
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuLogits1DMaxKernel::validate(src, &tensor_info_max_sum));
- ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuLogits1DSoftmaxKernel<IS_LOG>::validate(
- &tensor_info_tmp, &tensor_info_max_sum, dst, beta, &dont_care));
+ ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuSoftmaxKernel::validate(src, dst, beta, is_log, &tensor_info_tmp));
return Status{};
}
-template <bool IS_LOG>
-void CpuSoftmaxGeneric<IS_LOG>::run(ITensorPack &tensors)
+void CpuSoftmaxGeneric::run(ITensorPack &tensors)
{
ARM_COMPUTE_ERROR_ON_MSG(tensors.empty(), "No inputs provided");
@@ -181,13 +164,11 @@ void CpuSoftmaxGeneric<IS_LOG>::run(ITensorPack &tensors)
auto dst = tensors.get_tensor(TensorType::ACL_DST);
CpuAuxTensorHandler tmp(offset_int_vec(InternalTensorIdx::TMP), _tmp, tensors, true);
- CpuAuxTensorHandler max(offset_int_vec(InternalTensorIdx::MAX), _max, tensors, true);
CpuAuxTensorHandler input_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), _input_permuted, tensors, true);
CpuAuxTensorHandler output_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_DST), _output_permuted, tensors,
true);
- ITensorPack max_pack;
ITensorPack softmax_pack;
if (_needs_permute)
@@ -195,24 +176,15 @@ void CpuSoftmaxGeneric<IS_LOG>::run(ITensorPack &tensors)
ITensorPack permute_in_pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, input_permuted.get()}};
_permute_input.run(permute_in_pack);
- max_pack = {{TensorType::ACL_SRC, input_permuted.get()}, {TensorType::ACL_DST, max.get()}};
-
softmax_pack = {{TensorType::ACL_SRC_0, input_permuted.get()},
- {TensorType::ACL_SRC_1, max.get()},
{TensorType::ACL_DST_0, output_permuted.get()},
{TensorType::ACL_DST_1, tmp.get()}};
}
else
{
- max_pack = {{TensorType::ACL_SRC, src}, {TensorType::ACL_DST, max.get()}};
-
- softmax_pack = {{TensorType::ACL_SRC_0, src},
- {TensorType::ACL_SRC_1, max.get()},
- {TensorType::ACL_DST_0, dst},
- {TensorType::ACL_DST_1, tmp.get()}};
+ softmax_pack = {{TensorType::ACL_SRC_0, src}, {TensorType::ACL_DST_0, dst}, {TensorType::ACL_DST_1, tmp.get()}};
}
- NEScheduler::get().schedule_op(_max_kernel.get(), Window::DimY, _max_kernel->window(), max_pack);
NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimY, _softmax_kernel->window(), softmax_pack);
if (_needs_permute)
@@ -224,13 +196,10 @@ void CpuSoftmaxGeneric<IS_LOG>::run(ITensorPack &tensors)
}
}
-template <bool IS_LOG>
-experimental::MemoryRequirements CpuSoftmaxGeneric<IS_LOG>::workspace() const
+experimental::MemoryRequirements CpuSoftmaxGeneric::workspace() const
{
return _aux_mem;
}
-template class CpuSoftmaxGeneric<false>;
-template class CpuSoftmaxGeneric<true>;
} // namespace cpu
} // namespace arm_compute