aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/gpu/cl/operators/ClSoftmax.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/gpu/cl/operators/ClSoftmax.h')
-rw-r--r--src/runtime/gpu/cl/operators/ClSoftmax.h36
1 files changed, 7 insertions, 29 deletions
diff --git a/src/runtime/gpu/cl/operators/ClSoftmax.h b/src/runtime/gpu/cl/operators/ClSoftmax.h
index e38b7c595a..f19a51fc5e 100644
--- a/src/runtime/gpu/cl/operators/ClSoftmax.h
+++ b/src/runtime/gpu/cl/operators/ClSoftmax.h
@@ -67,7 +67,7 @@ public:
experimental::MemoryRequirements workspace() const override;
private:
- enum class InternalTensorIdx
+ enum InternalTensorIdx
{
MAX = 0,
SUM,
@@ -77,41 +77,19 @@ private:
COUNT
};
- /** Create a single internal tensor
- *
- * @param[in] info The information used to create a tensor
- * @param[in] idx The index within the internal array the created tensor will be held
- */
- void create_internal_tensor(TensorInfo &info, InternalTensorIdx idx);
- /** Create all required internal tensors */
- void create_internal_tensor();
- /** Function to convert from internal tensor index to @ref TensorType used externally */
- TensorType convert_internal_idx_to_tensor_type(InternalTensorIdx idx) const;
- /** Function to import workspace memory allocated by the caller into internal tensor instances */
- void import_workspace_memory(ITensorPack &tensors);
- /** Function to permute the given source tensor when permutation is required */
- void run_source_permute(const ITensor *src);
- /** Function to permute the intemediate tensor to the final destination tensor when permutation is required */
- void run_destination_permute(ITensor *dst);
- /** Function to run @ref arm_compute::opencl::kernels::ClLogits1DMaxShiftExpSumKernel */
- void run_max_sum(const ITensor *src);
- /** Function to run @ref kernels::ClLogits1DNormKernel */
- void run_norm(ITensor *dst);
-
std::unique_ptr<ClPermute> _permute_input;
std::unique_ptr<ClPermute> _permute_output;
std::unique_ptr<kernels::ClLogits1DMaxShiftExpSumKernel> _max_shift_exp_sum_kernel;
std::unique_ptr<kernels::ClLogits1DNormKernel> _norm_kernel;
bool _needs_permute{ false };
- std::array<TensorInfo, static_cast<uint32_t>(InternalTensorIdx::COUNT)> _internal_info{};
- std::array<std::unique_ptr<CLTensor>, static_cast<uint32_t>(InternalTensorIdx::COUNT)> _internal_tensor{};
+ TensorInfo _max_info;
+ TensorInfo _sum_info;
+ TensorInfo _tmp_info;
+ TensorInfo _permuted_src_info;
+ TensorInfo _permuted_dst_info;
- TensorInfo &_max_info;
- TensorInfo &_sum_info;
- TensorInfo &_tmp_info;
- TensorInfo &_permuted_src_info;
- TensorInfo &_permuted_dst_info;
+ experimental::MemoryRequirements _aux_mem{};
};
} // opencl