aboutsummaryrefslogtreecommitdiff
path: root/src/gpu/cl/operators/ClSoftmax.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/gpu/cl/operators/ClSoftmax.h')
-rw-r--r--src/gpu/cl/operators/ClSoftmax.h45
1 files changed, 16 insertions, 29 deletions
diff --git a/src/gpu/cl/operators/ClSoftmax.h b/src/gpu/cl/operators/ClSoftmax.h
index 6c2aaaea80..232fcfebd1 100644
--- a/src/gpu/cl/operators/ClSoftmax.h
+++ b/src/gpu/cl/operators/ClSoftmax.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,25 +21,26 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_CL_SOFTMAX_H
-#define ARM_COMPUTE_CL_SOFTMAX_H
+#ifndef ACL_SRC_GPU_CL_OPERATORS_CLSOFTMAX_H
+#define ACL_SRC_GPU_CL_OPERATORS_CLSOFTMAX_H
+#include "arm_compute/core/experimental/Types.h"
#include "arm_compute/runtime/CL/CLTensor.h"
-#include "src/gpu/cl/ClCompileContext.h"
#include "src/gpu/cl/IClOperator.h"
namespace arm_compute
{
+class CLCompileContext;
+class ITensorInfo;
+class ITensorPack;
struct SoftmaxKernelInfo;
namespace opencl
{
-class ClPermute;
namespace kernels
{
-class ClLogits1DMaxShiftExpSumKernel;
-class ClLogits1DNormKernel;
+class ClSoftmaxKernel;
} // namespace kernels
class ClSoftmax : public IClOperator
{
@@ -64,36 +65,22 @@ public:
* @return a status
*/
static Status validate(const ITensorInfo &src, const ITensorInfo &dst, const SoftmaxKernelInfo &info);
- // Inherited methods overridden:
- void run(ITensorPack &tensors) override;
+
+ void run(ITensorPack &tensors) override;
+
experimental::MemoryRequirements workspace() const override;
private:
enum InternalTensorIdx
{
- MAX = 0,
- SUM,
- TMP,
- PERMUTED_SRC,
- PERMUTED_DST,
- COUNT
+ TMP = 0,
+ COUNT,
};
- std::unique_ptr<ClPermute> _permute_input;
- std::unique_ptr<ClPermute> _permute_output;
- std::unique_ptr<kernels::ClLogits1DMaxShiftExpSumKernel> _max_shift_exp_sum_kernel;
- std::unique_ptr<kernels::ClLogits1DNormKernel> _norm_kernel;
- bool _needs_permute{false};
-
- TensorInfo _max_info;
- TensorInfo _sum_info;
- TensorInfo _tmp_info;
- TensorInfo _permuted_src_info;
- TensorInfo _permuted_dst_info;
-
- experimental::MemoryRequirements _aux_mem{};
+ TensorInfo _tmp_info{};
+ experimental::MemoryRequirements _aux_mem;
};
} // namespace opencl
} // namespace arm_compute
-#endif /* ARM_COMPUTE_CL_SOFTMAX_H */
+#endif // ACL_SRC_GPU_CL_OPERATORS_CLSOFTMAX_H