aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/cpu/operators/CpuSoftmax.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/cpu/operators/CpuSoftmax.h')
-rw-r--r--src/runtime/cpu/operators/CpuSoftmax.h32
1 files changed, 22 insertions, 10 deletions
diff --git a/src/runtime/cpu/operators/CpuSoftmax.h b/src/runtime/cpu/operators/CpuSoftmax.h
index 9f18e0e4c5..38817977b3 100644
--- a/src/runtime/cpu/operators/CpuSoftmax.h
+++ b/src/runtime/cpu/operators/CpuSoftmax.h
@@ -24,7 +24,7 @@
#ifndef ARM_COMPUTE_CPU_SOFTMAX_H
#define ARM_COMPUTE_CPU_SOFTMAX_H
-#include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/experimental/Types.h"
#include "src/core/cpu/ICpuKernel.h"
#include "src/runtime/cpu/ICpuOperator.h"
@@ -87,15 +87,27 @@ public:
experimental::MemoryRequirements workspace() const override;
private:
- CpuPermute _permute_input;
- CpuPermute _permute_output;
- std::unique_ptr<ICpuKernel> _max_kernel;
- std::unique_ptr<ICpuKernel> _softmax_kernel;
- std::unique_ptr<ITensorInfo> _max;
- std::unique_ptr<ITensorInfo> _tmp;
- std::unique_ptr<ITensorInfo> _input_permuted;
- std::unique_ptr<ITensorInfo> _output_permuted;
- bool _needs_permute;
+ enum InternalTensorIdx
+ {
+ MAX = 0,
+ TMP,
+ PERMUTED_SRC,
+ PERMUTED_DST,
+ COUNT
+ };
+
+ CpuPermute _permute_input;
+ CpuPermute _permute_output;
+ std::unique_ptr<ICpuKernel> _max_kernel;
+ std::unique_ptr<ICpuKernel> _softmax_kernel;
+
+ TensorInfo _max;
+ TensorInfo _tmp;
+ TensorInfo _input_permuted;
+ TensorInfo _output_permuted;
+
+ bool _needs_permute;
+ experimental::MemoryRequirements _aux_mem{};
};
using CpuSoftmax = CpuSoftmaxGeneric<false>;
using CpuLogSoftmax = CpuSoftmaxGeneric<true>;