diff options
Diffstat (limited to 'src/gpu/cl/operators/ClSoftmax.h')
-rw-r--r-- | src/gpu/cl/operators/ClSoftmax.h | 45 |
1 files changed, 16 insertions, 29 deletions
diff --git a/src/gpu/cl/operators/ClSoftmax.h b/src/gpu/cl/operators/ClSoftmax.h index 6c2aaaea80..232fcfebd1 100644 --- a/src/gpu/cl/operators/ClSoftmax.h +++ b/src/gpu/cl/operators/ClSoftmax.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,25 +21,26 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_CL_SOFTMAX_H -#define ARM_COMPUTE_CL_SOFTMAX_H +#ifndef ACL_SRC_GPU_CL_OPERATORS_CLSOFTMAX_H +#define ACL_SRC_GPU_CL_OPERATORS_CLSOFTMAX_H +#include "arm_compute/core/experimental/Types.h" #include "arm_compute/runtime/CL/CLTensor.h" -#include "src/gpu/cl/ClCompileContext.h" #include "src/gpu/cl/IClOperator.h" namespace arm_compute { +class CLCompileContext; +class ITensorInfo; +class ITensorPack; struct SoftmaxKernelInfo; namespace opencl { -class ClPermute; namespace kernels { -class ClLogits1DMaxShiftExpSumKernel; -class ClLogits1DNormKernel; +class ClSoftmaxKernel; } // namespace kernels class ClSoftmax : public IClOperator { @@ -64,36 +65,22 @@ public: * @return a status */ static Status validate(const ITensorInfo &src, const ITensorInfo &dst, const SoftmaxKernelInfo &info); - // Inherited methods overridden: - void run(ITensorPack &tensors) override; + + void run(ITensorPack &tensors) override; + experimental::MemoryRequirements workspace() const override; private: enum InternalTensorIdx { - MAX = 0, - SUM, - TMP, - PERMUTED_SRC, - PERMUTED_DST, - COUNT + TMP = 0, + COUNT, }; - std::unique_ptr<ClPermute> _permute_input; - std::unique_ptr<ClPermute> _permute_output; - std::unique_ptr<kernels::ClLogits1DMaxShiftExpSumKernel> _max_shift_exp_sum_kernel; - std::unique_ptr<kernels::ClLogits1DNormKernel> _norm_kernel; - bool _needs_permute{false}; - - TensorInfo _max_info; - TensorInfo _sum_info; - TensorInfo _tmp_info; - TensorInfo _permuted_src_info; - TensorInfo _permuted_dst_info; - - experimental::MemoryRequirements _aux_mem{}; + TensorInfo _tmp_info{}; + experimental::MemoryRequirements _aux_mem; }; } // namespace opencl } // namespace arm_compute -#endif /* ARM_COMPUTE_CL_SOFTMAX_H */ +#endif // ACL_SRC_GPU_CL_OPERATORS_CLSOFTMAX_H |