aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NESoftmaxLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NESoftmaxLayer.cpp')
-rw-r--r--src/runtime/NEON/functions/NESoftmaxLayer.cpp7
1 files changed, 4 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NESoftmaxLayer.cpp b/src/runtime/NEON/functions/NESoftmaxLayer.cpp
index e763caa3a3..4f773861d2 100644
--- a/src/runtime/NEON/functions/NESoftmaxLayer.cpp
+++ b/src/runtime/NEON/functions/NESoftmaxLayer.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "src/core/helpers/SoftmaxHelpers.h"
namespace arm_compute
{
@@ -53,7 +54,7 @@ void NESoftmaxLayerGeneric<IS_LOG>::configure(ITensor *input, ITensor *output, f
// Add to the memory manager _input_permuted
_memory_group.manage(&_input_permuted);
- _permute_input.configure(input, &_input_permuted, get_permutation_vector_from_softmax_axis(actual_axis));
+ _permute_input.configure(input, &_input_permuted, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
}
// We want to deal with a 2D input. Either it is the permuted version of the original input (4D case)
@@ -87,7 +88,7 @@ void NESoftmaxLayerGeneric<IS_LOG>::configure(ITensor *input, ITensor *output, f
_input_permuted.allocator()->allocate();
// Re-permute the permuted output into the requested (4D) output
- _permute_output.configure(&_output_permuted, output, get_permutation_vector_from_softmax_axis(actual_axis));
+ _permute_output.configure(&_output_permuted, output, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
// Allocate the intermediate permuted tensors
_output_permuted.allocator()->allocate();
@@ -128,7 +129,7 @@ Status NESoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const I
if(needs_permute)
{
- const PermutationVector permutation_vector = get_permutation_vector_from_softmax_axis(actual_axis);
+ const PermutationVector permutation_vector = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis);
const TensorShape permuted_shape = misc::shape_calculator::compute_permutation_output_shape(*input, permutation_vector);
TensorInfo input_permuted(input->clone()->set_tensor_shape(permuted_shape));
ARM_COMPUTE_RETURN_ON_ERROR(NEPermute::validate(input, &input_permuted, permutation_vector));