diff options
Diffstat (limited to 'src/runtime/cpu/operators/CpuSoftmax.h')
-rw-r--r-- | src/runtime/cpu/operators/CpuSoftmax.h | 32 |
1 files changed, 22 insertions, 10 deletions
diff --git a/src/runtime/cpu/operators/CpuSoftmax.h b/src/runtime/cpu/operators/CpuSoftmax.h index 9f18e0e4c5..38817977b3 100644 --- a/src/runtime/cpu/operators/CpuSoftmax.h +++ b/src/runtime/cpu/operators/CpuSoftmax.h @@ -24,7 +24,7 @@ #ifndef ARM_COMPUTE_CPU_SOFTMAX_H #define ARM_COMPUTE_CPU_SOFTMAX_H -#include "arm_compute/core/ITensorInfo.h" +#include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/experimental/Types.h" #include "src/core/cpu/ICpuKernel.h" #include "src/runtime/cpu/ICpuOperator.h" @@ -87,15 +87,27 @@ public: experimental::MemoryRequirements workspace() const override; private: - CpuPermute _permute_input; - CpuPermute _permute_output; - std::unique_ptr<ICpuKernel> _max_kernel; - std::unique_ptr<ICpuKernel> _softmax_kernel; - std::unique_ptr<ITensorInfo> _max; - std::unique_ptr<ITensorInfo> _tmp; - std::unique_ptr<ITensorInfo> _input_permuted; - std::unique_ptr<ITensorInfo> _output_permuted; - bool _needs_permute; + enum InternalTensorIdx + { + MAX = 0, + TMP, + PERMUTED_SRC, + PERMUTED_DST, + COUNT + }; + + CpuPermute _permute_input; + CpuPermute _permute_output; + std::unique_ptr<ICpuKernel> _max_kernel; + std::unique_ptr<ICpuKernel> _softmax_kernel; + + TensorInfo _max; + TensorInfo _tmp; + TensorInfo _input_permuted; + TensorInfo _output_permuted; + + bool _needs_permute; + experimental::MemoryRequirements _aux_mem{}; }; using CpuSoftmax = CpuSoftmaxGeneric<false>; using CpuLogSoftmax = CpuSoftmaxGeneric<true>; |