diff options
Diffstat (limited to 'src/runtime/NEON/functions/NESoftmaxLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NESoftmaxLayer.cpp | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/src/runtime/NEON/functions/NESoftmaxLayer.cpp b/src/runtime/NEON/functions/NESoftmaxLayer.cpp index e763caa3a3..4f773861d2 100644 --- a/src/runtime/NEON/functions/NESoftmaxLayer.cpp +++ b/src/runtime/NEON/functions/NESoftmaxLayer.cpp @@ -27,6 +27,7 @@ #include "arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/NEON/NEScheduler.h" +#include "src/core/helpers/SoftmaxHelpers.h" namespace arm_compute { @@ -53,7 +54,7 @@ void NESoftmaxLayerGeneric<IS_LOG>::configure(ITensor *input, ITensor *output, f // Add to the memory manager _input_permuted _memory_group.manage(&_input_permuted); - _permute_input.configure(input, &_input_permuted, get_permutation_vector_from_softmax_axis(actual_axis)); + _permute_input.configure(input, &_input_permuted, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis)); } // We want to deal with a 2D input. Either it is the permuted version of the original input (4D case) @@ -87,7 +88,7 @@ void NESoftmaxLayerGeneric<IS_LOG>::configure(ITensor *input, ITensor *output, f _input_permuted.allocator()->allocate(); // Re-permute the permuted output into the requested (4D) output - _permute_output.configure(&_output_permuted, output, get_permutation_vector_from_softmax_axis(actual_axis)); + _permute_output.configure(&_output_permuted, output, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis)); // Allocate the intermediate permuted tensors _output_permuted.allocator()->allocate(); @@ -128,7 +129,7 @@ Status NESoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const I if(needs_permute) { - const PermutationVector permutation_vector = get_permutation_vector_from_softmax_axis(actual_axis); + const PermutationVector permutation_vector = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis); const TensorShape permuted_shape = misc::shape_calculator::compute_permutation_output_shape(*input, permutation_vector); TensorInfo input_permuted(input->clone()->set_tensor_shape(permuted_shape)); ARM_COMPUTE_RETURN_ON_ERROR(NEPermute::validate(input, &input_permuted, permutation_vector)); |