aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorManuel Bottini <manuel.bottini@arm.com>2021-06-09 16:37:32 +0100
committerManuel Bottini <manuel.bottini@arm.com>2021-06-15 16:31:27 +0000
commit94f799e8f6f605333d40472860fb472e8ba6d83d (patch)
treece528244814463ed42dc86a84d54ea870c75d592
parent36dff9f81e3a95aea19fcc7246a4896930a14bc6 (diff)
downloadComputeLibrary-94f799e8f6f605333d40472860fb472e8ba6d83d.tar.gz
Fix incorrect memory handling in ported functions
Details of the functions: - ClSoftmax - CpuSoftmax - CpuPool2d Change-Id: Icd2c14d5df010c3b2301e2693ce6f414d7c61916 Resolves: COMPMID-4404 Signed-off-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5797 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/runtime/CL/functions/CLSoftmaxLayer.h3
-rw-r--r--arm_compute/runtime/NEON/functions/NEPoolingLayer.h2
-rw-r--r--arm_compute/runtime/NEON/functions/NESoftmaxLayer.h3
-rw-r--r--src/core/helpers/MemoryHelpers.h9
-rw-r--r--src/runtime/CL/functions/CLSoftmaxLayer.cpp44
-rw-r--r--src/runtime/NEON/functions/NEPoolingLayer.cpp27
-rw-r--r--src/runtime/NEON/functions/NESoftmaxLayer.cpp83
-rw-r--r--src/runtime/cpu/operators/CpuPool2d.cpp8
-rw-r--r--src/runtime/cpu/operators/CpuPool2d.h2
-rw-r--r--src/runtime/cpu/operators/CpuSoftmax.cpp99
-rw-r--r--src/runtime/cpu/operators/CpuSoftmax.h32
-rw-r--r--src/runtime/cpu/utils/CpuAuxTensorHandler.h101
-rw-r--r--src/runtime/gpu/cl/operators/ClSoftmax.cpp194
-rw-r--r--src/runtime/gpu/cl/operators/ClSoftmax.h36
14 files changed, 286 insertions, 357 deletions
diff --git a/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h b/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h
index 721a47144e..687f8ff6d8 100644
--- a/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h
+++ b/arm_compute/runtime/CL/functions/CLSoftmaxLayer.h
@@ -106,9 +106,6 @@ public:
private:
struct Impl;
std::unique_ptr<Impl> _impl;
-
- /** Allocate workspace required by the operator */
- void allocate_workspace();
};
using CLSoftmaxLayer = CLSoftmaxLayerGeneric<false>;
diff --git a/arm_compute/runtime/NEON/functions/NEPoolingLayer.h b/arm_compute/runtime/NEON/functions/NEPoolingLayer.h
index b5366fa1c1..9398e1fce9 100644
--- a/arm_compute/runtime/NEON/functions/NEPoolingLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEPoolingLayer.h
@@ -95,8 +95,6 @@ public:
void run() override;
private:
- MemoryGroup _memory_group;
-
struct Impl;
std::unique_ptr<Impl> _impl;
};
diff --git a/arm_compute/runtime/NEON/functions/NESoftmaxLayer.h b/arm_compute/runtime/NEON/functions/NESoftmaxLayer.h
index efe959f14e..02d0cc15b2 100644
--- a/arm_compute/runtime/NEON/functions/NESoftmaxLayer.h
+++ b/arm_compute/runtime/NEON/functions/NESoftmaxLayer.h
@@ -25,7 +25,7 @@
#define ARM_COMPUTE_NESOFTMAXLAYER_H
#include "arm_compute/runtime/IFunction.h"
-#include "arm_compute/runtime/MemoryGroup.h"
+#include "src/core/helpers/MemoryHelpers.h"
#include <memory>
namespace arm_compute
@@ -88,7 +88,6 @@ public:
void run() override;
private:
- MemoryGroup _memory_group;
struct Impl;
std::unique_ptr<Impl> _impl;
};
diff --git a/src/core/helpers/MemoryHelpers.h b/src/core/helpers/MemoryHelpers.h
index 6756a90c25..e751e6025d 100644
--- a/src/core/helpers/MemoryHelpers.h
+++ b/src/core/helpers/MemoryHelpers.h
@@ -46,6 +46,15 @@ using WorkspaceData = std::vector<std::pair<int, std::unique_ptr<TensorType>>>;
template <typename TensorType>
WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequirements &mem_reqs,
MemoryGroup &mgroup,
+ ITensorPack &run_pack)
+{
+ ITensorPack dummy_pack = ITensorPack();
+ return manage_workspace<TensorType>(mem_reqs, mgroup, run_pack, dummy_pack);
+}
+
+template <typename TensorType>
+WorkspaceData<TensorType> manage_workspace(const experimental::MemoryRequirements &mem_reqs,
+ MemoryGroup &mgroup,
ITensorPack &run_pack, ITensorPack &prep_pack)
{
WorkspaceData<TensorType> workspace_memory;
diff --git a/src/runtime/CL/functions/CLSoftmaxLayer.cpp b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
index fe45f65beb..de58bf1b02 100644
--- a/src/runtime/CL/functions/CLSoftmaxLayer.cpp
+++ b/src/runtime/CL/functions/CLSoftmaxLayer.cpp
@@ -29,6 +29,7 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "src/core/gpu/cl/kernels/ClSoftmaxKernel.h"
+#include "src/core/helpers/MemoryHelpers.h"
#include "src/runtime/gpu/cl/operators/ClPermute.h"
#include "src/runtime/gpu/cl/operators/ClSoftmax.h"
@@ -43,7 +44,8 @@ struct CLSoftmaxLayerGeneric<IS_LOG>::Impl
ICLTensor *dst{ nullptr };
std::unique_ptr<OperatorType> op{ nullptr };
MemoryGroup memory_group{};
- std::vector<std::pair<int, std::unique_ptr<CLTensor>>> workspace_tensors{};
+ ITensorPack run_pack{};
+ WorkspaceData<CLTensor> workspace_tensors{};
};
template <bool IS_LOG>
@@ -71,7 +73,9 @@ void CLSoftmaxLayerGeneric<IS_LOG>::configure(const CLCompileContext &compile_co
SoftmaxKernelInfo softmax_info{ beta, IS_LOG, input->info()->data_type(), axis };
_impl->op->configure(compile_context, *input->info(), *output->info(), softmax_info);
- allocate_workspace();
+
+ _impl->run_pack = { { TensorType::ACL_SRC, _impl->src }, { TensorType::ACL_DST, _impl->dst } };
+ _impl->workspace_tensors = manage_workspace<CLTensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack);
}
template <bool IS_LOG>
@@ -82,46 +86,12 @@ Status CLSoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const I
}
template <bool IS_LOG>
-void CLSoftmaxLayerGeneric<IS_LOG>::allocate_workspace()
-{
- const auto memory_requirements = _impl->op->workspace();
- std::for_each(memory_requirements.begin(), memory_requirements.end(), [this](const experimental::MemoryInfo & memory_info)
- {
- auto tensor_info = TensorInfo{ TensorShape(memory_info.size), 1, DataType::U8 };
- _impl->workspace_tensors.emplace_back(memory_info.slot, std::make_unique<CLTensor>());
- auto tensor = _impl->workspace_tensors.back().second.get();
- ARM_COMPUTE_ERROR_ON_NULLPTR(tensor);
- tensor->allocator()->init(tensor_info);
- _impl->memory_group.manage(tensor);
- });
-
- std::for_each(_impl->workspace_tensors.begin(), _impl->workspace_tensors.end(), [](std::pair<int, std::unique_ptr<CLTensor>> &wt)
- {
- auto tensor = wt.second.get();
- tensor->allocator()->allocate();
- });
-}
-
-template <bool IS_LOG>
void CLSoftmaxLayerGeneric<IS_LOG>::run()
{
// Acquire all the temporaries
MemoryGroupResourceScope scope_mg(_impl->memory_group);
-
ARM_COMPUTE_ERROR_ON_NULLPTR(_impl->src, _impl->dst);
-
- ITensorPack pack;
- pack.add_tensor(TensorType::ACL_SRC, _impl->src);
- pack.add_tensor(TensorType::ACL_DST, _impl->dst);
-
- std::for_each(_impl->workspace_tensors.begin(), _impl->workspace_tensors.end(), [&pack](std::pair<int, std::unique_ptr<CLTensor>> &wt)
- {
- auto tensor = wt.second.get();
- ARM_COMPUTE_ERROR_ON_NULLPTR(tensor);
- pack.add_tensor(wt.first, tensor);
- });
-
- _impl->op->run(pack);
+ _impl->op->run(_impl->run_pack);
}
template class CLSoftmaxLayerGeneric<false>;
diff --git a/src/runtime/NEON/functions/NEPoolingLayer.cpp b/src/runtime/NEON/functions/NEPoolingLayer.cpp
index bbf3e7cc4e..8d267a32c0 100644
--- a/src/runtime/NEON/functions/NEPoolingLayer.cpp
+++ b/src/runtime/NEON/functions/NEPoolingLayer.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/runtime/Tensor.h"
+#include "src/core/helpers/MemoryHelpers.h"
#include "src/runtime/cpu/operators/CpuPool2d.h"
namespace arm_compute
@@ -35,15 +36,18 @@ struct NEPoolingLayer::Impl
ITensor *src{ nullptr };
ITensor *dst{ nullptr };
ITensor *indices{ nullptr };
- Tensor workspace{ nullptr };
std::unique_ptr<cpu::CpuPool2d> op{ nullptr };
+ MemoryGroup memory_group{};
+ ITensorPack run_pack{};
+ WorkspaceData<Tensor> workspace_tensors{};
};
NEPoolingLayer::~NEPoolingLayer() = default;
NEPoolingLayer::NEPoolingLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(memory_manager), _impl(std::make_unique<Impl>())
+ : _impl(std::make_unique<Impl>())
{
+ _impl->memory_group = MemoryGroup(std::move(memory_manager));
}
void NEPoolingLayer::configure(ITensor *input, ITensor *output, const PoolingLayerInfo &pool_info, ITensor *indices)
@@ -54,14 +58,8 @@ void NEPoolingLayer::configure(ITensor *input, ITensor *output, const PoolingLay
_impl->op = std::make_unique<cpu::CpuPool2d>();
_impl->op->configure(input->info(), output->info(), pool_info, (indices) ? indices->info() : nullptr);
- // Allocate workspace based on kernel's memory requirements
- const experimental::MemoryRequirements mem_req = _impl->op->workspace();
- if(!mem_req.empty())
- {
- _impl->workspace.allocator()->init(TensorInfo(TensorShape{ (mem_req[0].size + mem_req[0].alignment) }, 1, DataType::S8), mem_req[0].alignment);
- _memory_group.manage(&_impl->workspace);
- _impl->workspace.allocator()->allocate();
- }
+ _impl->run_pack = { { TensorType::ACL_SRC, _impl->src }, { TensorType::ACL_DST_0, _impl->dst }, { TensorType::ACL_DST_1, _impl->indices } };
+ _impl->workspace_tensors = manage_workspace<Tensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack);
}
Status NEPoolingLayer::validate(const ITensorInfo *input, const ITensorInfo *output, const PoolingLayerInfo &pool_info, const ITensorInfo *indices)
@@ -71,11 +69,8 @@ Status NEPoolingLayer::validate(const ITensorInfo *input, const ITensorInfo *out
void NEPoolingLayer::run()
{
- ITensorPack pack;
- pack.add_tensor(TensorType::ACL_SRC, _impl->src);
- pack.add_tensor(TensorType::ACL_DST_0, _impl->dst);
- pack.add_tensor(TensorType::ACL_DST_1, _impl->indices);
- pack.add_tensor(TensorType::ACL_INT_0, &_impl->workspace);
- _impl->op->run(pack);
+ MemoryGroupResourceScope scope_mg(_impl->memory_group);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_impl->src, _impl->dst);
+ _impl->op->run(_impl->run_pack);
}
} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NESoftmaxLayer.cpp b/src/runtime/NEON/functions/NESoftmaxLayer.cpp
index 3f1e43a8f2..af8546d4ca 100644
--- a/src/runtime/NEON/functions/NESoftmaxLayer.cpp
+++ b/src/runtime/NEON/functions/NESoftmaxLayer.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/runtime/NEON/functions/NESoftmaxLayer.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/Tensor.h"
#include "src/core/cpu/kernels/CpuSoftmaxKernel.h"
#include "src/core/helpers/SoftmaxHelpers.h"
@@ -36,16 +37,17 @@ struct NESoftmaxLayerGeneric<IS_LOG>::Impl
const ITensor *src{ nullptr };
ITensor *dst{ nullptr };
Tensor max{ nullptr };
- Tensor tmp{ nullptr };
- Tensor input_permuted{ nullptr };
- Tensor output_permuted{ nullptr };
std::unique_ptr<cpu::CpuSoftmaxGeneric<IS_LOG>> op{ nullptr };
+ MemoryGroup memory_group{};
+ ITensorPack run_pack{};
+ WorkspaceData<Tensor> workspace_tensors{};
};
template <bool IS_LOG>
NESoftmaxLayerGeneric<IS_LOG>::NESoftmaxLayerGeneric(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _impl(std::make_unique<Impl>())
+ : _impl(std::make_unique<Impl>())
{
+ _impl->memory_group = MemoryGroup(std::move(memory_manager));
}
template <bool IS_LOG>
@@ -65,64 +67,8 @@ void NESoftmaxLayerGeneric<IS_LOG>::configure(ITensor *input, ITensor *output, f
_impl->op = std::make_unique<cpu::CpuSoftmaxGeneric<IS_LOG>>();
_impl->op->configure(input->info(), output->info(), beta, axis);
- const unsigned int actual_axis = static_cast<unsigned int>(wrap_around(axis, static_cast<int32_t>(input->info()->num_dimensions())));
- const bool needs_permute = actual_axis > 0;
- if(needs_permute)
- {
- // Add to the memory manager _input_permuted
- auto permute_input = std::make_unique<cpu::CpuPermute>();
- _memory_group.manage(&_impl->input_permuted);
- permute_input->configure(input->info(), _impl->input_permuted.info(), softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
- }
-
- // We want to deal with a 2D input. Either it is the permuted version of the original input (4D case)
- // or it is the original input case (2D case)
- ITensor *tmp_input = (needs_permute ? &_impl->input_permuted : input);
-
- // Create intermediate tensors shapes
- const TensorInfo input_info = tmp_input->info()->clone()->reset_padding().set_is_resizable(true);
- DataType tmp_data_type = is_data_type_quantized_asymmetric(tmp_input->info()->data_type()) ? DataType::F32 : tmp_input->info()->data_type();
- TensorInfo tensor_info_tmp(input_info.clone()->set_data_type(tmp_data_type));
-
- // Init intermediate tensors
- TensorShape max_sum_shape = tmp_input->info()->tensor_shape();
- max_sum_shape.set(0, 1);
- _impl->max.allocator()->init(input_info.clone()->set_tensor_shape(max_sum_shape));
- _impl->tmp.allocator()->init(tensor_info_tmp);
-
- // Manage intermediate buffers
- _memory_group.manage(&_impl->max);
- _memory_group.manage(&_impl->tmp);
-
- // Configure kernels
- auto max_kernel = std::make_unique<cpu::kernels::CpuLogits1DMaxKernel>();
- auto softmax_kernel = std::make_unique<cpu::kernels::CpuLogits1DSoftmaxKernel<IS_LOG>>();
- max_kernel->configure(tmp_input->info(), _impl->max.info());
-
- if(needs_permute)
- {
- auto permute_output = std::make_unique<cpu::CpuPermute>();
- // Add to the memory manager _output_permuted
- _memory_group.manage(&_impl->output_permuted);
-
- // The normalization kernel stores the result in a permuted output tensor
- softmax_kernel->configure(tmp_input->info(), _impl->max.info(), _impl->output_permuted.info(), beta, _impl->tmp.info());
- _impl->input_permuted.allocator()->allocate();
-
- // Re-permute the permuted output into the requested (4D) output
- permute_output->configure(_impl->output_permuted.info(), output->info(), softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
-
- // Allocate the intermediate permuted tensors
- _impl->output_permuted.allocator()->allocate();
- }
- else
- {
- softmax_kernel->configure(tmp_input->info(), _impl->max.info(), output->info(), beta, _impl->tmp.info());
- }
-
- // Allocate intermediate buffers
- _impl->max.allocator()->allocate();
- _impl->tmp.allocator()->allocate();
+ _impl->run_pack = { { TensorType::ACL_SRC, _impl->src }, { TensorType::ACL_DST, _impl->dst } };
+ _impl->workspace_tensors = manage_workspace<Tensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack);
}
template <bool IS_LOG>
@@ -136,15 +82,10 @@ Status NESoftmaxLayerGeneric<IS_LOG>::validate(const ITensorInfo *input, const I
template <bool IS_LOG>
void NESoftmaxLayerGeneric<IS_LOG>::run()
{
- MemoryGroupResourceScope scope_mg(_memory_group);
- ITensorPack pack;
- pack.add_tensor(TensorType::ACL_SRC, _impl->src);
- pack.add_tensor(TensorType::ACL_DST, _impl->dst);
- pack.add_tensor(TensorType::ACL_INT_0, &_impl->tmp);
- pack.add_tensor(TensorType::ACL_INT_1, &_impl->max);
- pack.add_tensor(TensorType::ACL_INT_2, &_impl->input_permuted);
- pack.add_tensor(TensorType::ACL_INT_3, &_impl->output_permuted);
- _impl->op->run(pack);
+ // Acquire all the temporaries
+ MemoryGroupResourceScope scope_mg(_impl->memory_group);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(_impl->src, _impl->dst);
+ _impl->op->run(_impl->run_pack);
}
template class NESoftmaxLayerGeneric<false>;
diff --git a/src/runtime/cpu/operators/CpuPool2d.cpp b/src/runtime/cpu/operators/CpuPool2d.cpp
index b225199c40..e746c8fb3b 100644
--- a/src/runtime/cpu/operators/CpuPool2d.cpp
+++ b/src/runtime/cpu/operators/CpuPool2d.cpp
@@ -30,6 +30,8 @@
#include "src/core/cpu/kernels/CpuPool2dKernel.h"
#include "src/core/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.h"
+using namespace arm_compute::experimental;
+
namespace arm_compute
{
namespace cpu
@@ -40,7 +42,7 @@ CpuPool2d::CpuPool2d()
_asm_glue(),
_is_global_pooling_layer(false),
_data_layout(DataLayout::NCHW),
- _mem_req()
+ _aux_mem(1)
{
}
@@ -71,7 +73,7 @@ void CpuPool2d::configure(ITensorInfo *src, ITensorInfo *dst, const PoolingLayer
// Get kernel's memory requirements
constexpr size_t alignment = 4096;
const size_t workspace_size = pooling_wrapper->get_working_size(num_threads);
- _mem_req.push_back({ TensorType::ACL_INT_0, workspace_size, alignment });
+ _aux_mem[0] = MemoryInfo(TensorType::ACL_INT_0, MemoryLifetime::Temporary, workspace_size, alignment);
_asm_glue = std::move(pooling_wrapper);
}
@@ -150,7 +152,7 @@ void CpuPool2d::run(ITensorPack &tensors)
experimental::MemoryRequirements CpuPool2d::workspace() const
{
- return _mem_req;
+ return _aux_mem;
}
} // namespace cpu
} // namespace arm_compute
diff --git a/src/runtime/cpu/operators/CpuPool2d.h b/src/runtime/cpu/operators/CpuPool2d.h
index ae3d115dfc..68416b5cfc 100644
--- a/src/runtime/cpu/operators/CpuPool2d.h
+++ b/src/runtime/cpu/operators/CpuPool2d.h
@@ -80,7 +80,7 @@ private:
bool _is_global_pooling_layer;
DataLayout _data_layout;
- experimental::MemoryRequirements _mem_req;
+ experimental::MemoryRequirements _aux_mem{};
};
} // namespace cpu
} // namespace arm_compute
diff --git a/src/runtime/cpu/operators/CpuSoftmax.cpp b/src/runtime/cpu/operators/CpuSoftmax.cpp
index 0e1bcd5c69..e17925ee50 100644
--- a/src/runtime/cpu/operators/CpuSoftmax.cpp
+++ b/src/runtime/cpu/operators/CpuSoftmax.cpp
@@ -29,7 +29,11 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "src/core/cpu/kernels/CpuSoftmaxKernel.h"
+#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/helpers/SoftmaxHelpers.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
+
+using namespace arm_compute::experimental;
namespace arm_compute
{
@@ -37,7 +41,16 @@ namespace cpu
{
template <bool IS_LOG>
CpuSoftmaxGeneric<IS_LOG>::CpuSoftmaxGeneric()
- : _permute_input(), _permute_output(), _max_kernel(), _softmax_kernel(), _max(nullptr), _tmp(nullptr), _input_permuted(nullptr), _output_permuted(nullptr), _needs_permute(false)
+ : _permute_input(),
+ _permute_output(),
+ _max_kernel(),
+ _softmax_kernel(),
+ _max(),
+ _tmp(),
+ _input_permuted(),
+ _output_permuted(),
+ _needs_permute(false),
+ _aux_mem(InternalTensorIdx::COUNT)
{
}
@@ -54,13 +67,12 @@ void CpuSoftmaxGeneric<IS_LOG>::configure(const ITensorInfo *src, ITensorInfo *d
if(_needs_permute)
{
- _input_permuted = std::make_unique<TensorInfo>();
- _permute_input.configure(src, _input_permuted.get(), softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
+ _permute_input.configure(src, &_input_permuted, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
}
// We want to deal with a 2D input. Either it is the permuted version of the original input (4D case)
// or it is the original input case (2D case)
- const ITensorInfo *tmp_input = (_needs_permute ? _input_permuted.get() : src);
+ const ITensorInfo *tmp_input = (_needs_permute ? &_input_permuted : src);
// Create intermediate tensors shapes
TensorShape max_sum_shape = tmp_input->tensor_shape();
@@ -71,31 +83,35 @@ void CpuSoftmaxGeneric<IS_LOG>::configure(const ITensorInfo *src, ITensorInfo *d
TensorInfo max_info(tmp_input->clone()->set_tensor_shape(max_sum_shape));
// Init intermediate tensors
- _max = std::make_unique<TensorInfo>(max_info);
- _tmp = std::make_unique<TensorInfo>(tensor_info_tmp);
+ _max = TensorInfo(max_info);
+ _tmp = TensorInfo(tensor_info_tmp);
// Configure kernels
auto mk = std::make_unique<kernels::CpuLogits1DMaxKernel>();
- mk->configure(tmp_input, _max.get());
+ mk->configure(tmp_input, &_max);
_max_kernel = std::move(mk);
auto sm = std::make_unique<kernels::CpuLogits1DSoftmaxKernel<IS_LOG>>();
if(_needs_permute)
{
- _output_permuted = std::make_unique<TensorInfo>();
-
// The normalization kernel stores the result in a permuted output tensor
- sm->configure(tmp_input, _max.get(), _output_permuted.get(), beta, _tmp.get());
+ sm->configure(tmp_input, &_max, &_output_permuted, beta, &_tmp);
// Re-permute the permuted output into the requested (4D) output
- _permute_output.configure(_output_permuted.get(), dst, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
+ _permute_output.configure(&_output_permuted, dst, softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis));
}
else
{
// Softmax 2D case
- sm->configure(tmp_input, _max.get(), dst, beta, _tmp.get());
+ sm->configure(tmp_input, &_max, dst, beta, &_tmp);
}
_softmax_kernel = std::move(sm);
+
+ _aux_mem[InternalTensorIdx::MAX] = MemoryInfo(offset_int_vec(InternalTensorIdx::MAX), MemoryLifetime::Temporary, _max.total_size());
+ _aux_mem[InternalTensorIdx::TMP] = MemoryInfo(offset_int_vec(InternalTensorIdx::TMP), MemoryLifetime::Temporary, _tmp.total_size());
+
+ _aux_mem[InternalTensorIdx::PERMUTED_SRC] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), MemoryLifetime::Temporary, _input_permuted.total_size());
+ _aux_mem[InternalTensorIdx::PERMUTED_DST] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_DST), MemoryLifetime::Temporary, _output_permuted.total_size());
}
template <bool IS_LOG>
@@ -141,42 +157,54 @@ void CpuSoftmaxGeneric<IS_LOG>::run(ITensorPack &tensors)
{
ARM_COMPUTE_ERROR_ON_MSG(tensors.empty(), "No inputs provided");
+ auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
+ auto dst = tensors.get_tensor(TensorType::ACL_DST);
+
+ CpuAuxTensorHandler tmp(offset_int_vec(InternalTensorIdx::TMP), _tmp, tensors, false);
+ CpuAuxTensorHandler max(offset_int_vec(InternalTensorIdx::MAX), _max, tensors, false);
+
+ CpuAuxTensorHandler input_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), _input_permuted, tensors, false);
+ CpuAuxTensorHandler output_permuted(offset_int_vec(InternalTensorIdx::PERMUTED_DST), _output_permuted, tensors, false);
+
ITensorPack max_pack;
ITensorPack softmax_pack;
if(_needs_permute)
{
- ITensorPack permute_in_pack;
- permute_in_pack.add_tensor(TensorType::ACL_SRC, tensors.get_const_tensor(ACL_SRC));
- permute_in_pack.add_tensor(TensorType::ACL_DST, tensors.get_tensor(ACL_INT_2));
+ ITensorPack permute_in_pack = { { TensorType::ACL_SRC, src }, { TensorType::ACL_DST, input_permuted.get() } };
_permute_input.run(permute_in_pack);
- max_pack.add_tensor(TensorType::ACL_SRC, tensors.get_tensor(ACL_INT_2));
+ max_pack = { { TensorType::ACL_SRC, input_permuted.get() }, { TensorType::ACL_DST, max.get() } };
- softmax_pack.add_tensor(TensorType::ACL_SRC_0, tensors.get_tensor(ACL_INT_2));
- softmax_pack.add_tensor(TensorType::ACL_SRC_1, tensors.get_tensor(ACL_INT_1));
- softmax_pack.add_tensor(TensorType::ACL_DST_0, tensors.get_tensor(ACL_INT_3));
- softmax_pack.add_tensor(TensorType::ACL_DST_1, tensors.get_tensor(ACL_INT_0));
+ softmax_pack =
+ {
+ { TensorType::ACL_SRC_0, input_permuted.get() },
+ { TensorType::ACL_SRC_1, max.get() },
+ { TensorType::ACL_DST_0, output_permuted.get() },
+ { TensorType::ACL_DST_1, tmp.get() }
+ };
}
else
{
- max_pack.add_tensor(TensorType::ACL_SRC, tensors.get_const_tensor(ACL_SRC));
- softmax_pack.add_tensor(TensorType::ACL_SRC_0, tensors.get_const_tensor(ACL_SRC));
- softmax_pack.add_tensor(TensorType::ACL_SRC_1, tensors.get_tensor(ACL_INT_1));
- softmax_pack.add_tensor(TensorType::ACL_DST_0, tensors.get_tensor(ACL_DST));
- softmax_pack.add_tensor(TensorType::ACL_DST_1, tensors.get_tensor(ACL_INT_0));
+ max_pack = { { TensorType::ACL_SRC, src }, { TensorType::ACL_DST, max.get() } };
+
+ softmax_pack =
+ {
+ { TensorType::ACL_SRC_0, src },
+ { TensorType::ACL_SRC_1, max.get() },
+ { TensorType::ACL_DST_0, dst },
+ { TensorType::ACL_DST_1, tmp.get() }
+ };
}
- max_pack.add_tensor(TensorType::ACL_DST, tensors.get_tensor(ACL_INT_1));
-
NEScheduler::get().schedule_op(_max_kernel.get(), Window::DimY, _max_kernel->window(), max_pack);
NEScheduler::get().schedule_op(_softmax_kernel.get(), Window::DimY, _softmax_kernel->window(), softmax_pack);
if(_needs_permute)
{
ITensorPack permute_out_pack;
- permute_out_pack.add_tensor(TensorType::ACL_SRC, tensors.get_tensor(ACL_INT_3));
- permute_out_pack.add_tensor(TensorType::ACL_DST, tensors.get_tensor(ACL_DST));
+ permute_out_pack.add_tensor(TensorType::ACL_SRC, output_permuted.get());
+ permute_out_pack.add_tensor(TensorType::ACL_DST, dst);
_permute_output.run(permute_out_pack);
}
}
@@ -184,18 +212,7 @@ void CpuSoftmaxGeneric<IS_LOG>::run(ITensorPack &tensors)
template <bool IS_LOG>
experimental::MemoryRequirements CpuSoftmaxGeneric<IS_LOG>::workspace() const
{
- experimental::MemoryRequirements req{};
-
- req.push_back({ TensorType::ACL_INT_0, _tmp->total_size(), 0 });
- req.push_back({ TensorType::ACL_INT_1, _max->total_size(), 0 });
-
- if(_needs_permute)
- {
- req.push_back({ TensorType::ACL_INT_2, _input_permuted->total_size(), 0 });
- req.push_back({ TensorType::ACL_INT_3, _output_permuted->total_size(), 0 });
- }
-
- return req;
+ return _aux_mem;
}
template class CpuSoftmaxGeneric<false>;
diff --git a/src/runtime/cpu/operators/CpuSoftmax.h b/src/runtime/cpu/operators/CpuSoftmax.h
index 9f18e0e4c5..38817977b3 100644
--- a/src/runtime/cpu/operators/CpuSoftmax.h
+++ b/src/runtime/cpu/operators/CpuSoftmax.h
@@ -24,7 +24,7 @@
#ifndef ARM_COMPUTE_CPU_SOFTMAX_H
#define ARM_COMPUTE_CPU_SOFTMAX_H
-#include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/experimental/Types.h"
#include "src/core/cpu/ICpuKernel.h"
#include "src/runtime/cpu/ICpuOperator.h"
@@ -87,15 +87,27 @@ public:
experimental::MemoryRequirements workspace() const override;
private:
- CpuPermute _permute_input;
- CpuPermute _permute_output;
- std::unique_ptr<ICpuKernel> _max_kernel;
- std::unique_ptr<ICpuKernel> _softmax_kernel;
- std::unique_ptr<ITensorInfo> _max;
- std::unique_ptr<ITensorInfo> _tmp;
- std::unique_ptr<ITensorInfo> _input_permuted;
- std::unique_ptr<ITensorInfo> _output_permuted;
- bool _needs_permute;
+ enum InternalTensorIdx
+ {
+ MAX = 0,
+ TMP,
+ PERMUTED_SRC,
+ PERMUTED_DST,
+ COUNT
+ };
+
+ CpuPermute _permute_input;
+ CpuPermute _permute_output;
+ std::unique_ptr<ICpuKernel> _max_kernel;
+ std::unique_ptr<ICpuKernel> _softmax_kernel;
+
+ TensorInfo _max;
+ TensorInfo _tmp;
+ TensorInfo _input_permuted;
+ TensorInfo _output_permuted;
+
+ bool _needs_permute;
+ experimental::MemoryRequirements _aux_mem{};
};
using CpuSoftmax = CpuSoftmaxGeneric<false>;
using CpuLogSoftmax = CpuSoftmaxGeneric<true>;
diff --git a/src/runtime/cpu/utils/CpuAuxTensorHandler.h b/src/runtime/cpu/utils/CpuAuxTensorHandler.h
new file mode 100644
index 0000000000..644018a718
--- /dev/null
+++ b/src/runtime/cpu/utils/CpuAuxTensorHandler.h
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CPU_UTILS_CPU_AUX_TENSOR_HANDLER_H
+#define ARM_COMPUTE_CPU_UTILS_CPU_AUX_TENSOR_HANDLER_H
+
+#include "arm_compute/core/ITensorPack.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/runtime/Tensor.h"
+
+#include "support/Cast.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+/* Tensor handler to wrap and handle tensor allocations on workspace buffers */
+class CpuAuxTensorHandler
+{
+public:
+ CpuAuxTensorHandler(int slot_id, TensorInfo &info, ITensorPack &pack, bool pack_inject = false)
+ : _tensor()
+ {
+ _tensor.allocator()->soft_init(info);
+
+ ITensor *packed_tensor = utils::cast::polymorphic_downcast<ITensor *>(pack.get_tensor(slot_id));
+ if((packed_tensor == nullptr) || (info.total_size() > packed_tensor->info()->total_size()))
+ {
+ _tensor.allocator()->allocate();
+ if(pack_inject)
+ {
+ pack.add_tensor(slot_id, &_tensor);
+ _injected_tensor_pack = &pack;
+ _injected_slot_id = slot_id;
+ }
+ }
+ else
+ {
+ _tensor.allocator()->import_memory(packed_tensor->buffer());
+ }
+ }
+
+ CpuAuxTensorHandler(TensorInfo &info, ITensor &tensor)
+ : _tensor()
+ {
+ _tensor.allocator()->soft_init(info);
+ if(info.total_size() <= tensor.info()->total_size())
+ {
+ _tensor.allocator()->import_memory(tensor.buffer());
+ }
+ }
+
+ CpuAuxTensorHandler(const CpuAuxTensorHandler &) = delete;
+ CpuAuxTensorHandler &operator=(const CpuAuxTensorHandler) = delete;
+
+ ~CpuAuxTensorHandler()
+ {
+ if(_injected_tensor_pack)
+ {
+ _injected_tensor_pack->remove_tensor(_injected_slot_id);
+ }
+ }
+
+ ITensor *get()
+ {
+ return &_tensor;
+ }
+
+ ITensor *operator()()
+ {
+ return &_tensor;
+ }
+
+private:
+ Tensor _tensor{};
+ ITensorPack *_injected_tensor_pack{ nullptr };
+ int _injected_slot_id{ TensorType::ACL_UNKNOWN };
+};
+} // namespace cpu
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_CPU_UTILS_CPU_AUX_TENSOR_HANDLER_H */ \ No newline at end of file
diff --git a/src/runtime/gpu/cl/operators/ClSoftmax.cpp b/src/runtime/gpu/cl/operators/ClSoftmax.cpp
index c3ec7cc0da..975bb0b932 100644
--- a/src/runtime/gpu/cl/operators/ClSoftmax.cpp
+++ b/src/runtime/gpu/cl/operators/ClSoftmax.cpp
@@ -24,82 +24,30 @@
#include "src/runtime/gpu/cl/operators/ClSoftmax.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "src/core/gpu/cl/kernels/ClSoftmaxKernel.h"
+#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/helpers/SoftmaxHelpers.h"
#include "src/runtime/gpu/cl/operators/ClPermute.h"
+#include "src/runtime/gpu/cl/utils/ClAuxTensorHandler.h"
#include "support/Cast.h"
+using namespace arm_compute::experimental;
+
namespace arm_compute
{
namespace opencl
{
-namespace
-{
-void run_permute(ClPermute *op, const ITensor *src, ITensor *dst)
-{
- ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst, op);
- ITensorPack pack;
- pack.add_const_tensor(TensorType::ACL_SRC, src);
- pack.add_tensor(TensorType::ACL_DST, dst);
- op->run(pack);
-}
-} // namespace
-
ClSoftmax::ClSoftmax()
: _permute_input(std::make_unique<ClPermute>()),
_permute_output(std::make_unique<ClPermute>()),
_max_shift_exp_sum_kernel(std::make_unique<kernels::ClLogits1DMaxShiftExpSumKernel>()),
_norm_kernel(std::make_unique<kernels::ClLogits1DNormKernel>()),
- _max_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::MAX)]),
- _sum_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::SUM)]),
- _tmp_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::TMP)]),
- _permuted_src_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_SRC)]),
- _permuted_dst_info(_internal_info[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_DST)])
-{
-}
-
-TensorType ClSoftmax::convert_internal_idx_to_tensor_type(InternalTensorIdx idx) const
-{
- switch(idx)
- {
- case InternalTensorIdx::MAX:
- return TensorType::ACL_INT_0;
- case InternalTensorIdx::SUM:
- return TensorType::ACL_INT_1;
- case InternalTensorIdx::TMP:
- return TensorType::ACL_INT_2;
- case InternalTensorIdx::PERMUTED_SRC:
- return TensorType::ACL_INT_3;
- case InternalTensorIdx::PERMUTED_DST:
- return TensorType::ACL_INT_4;
- default:
- ARM_COMPUTE_ERROR("invalid internal tensor index is given.");
- break;
- };
- return TensorType::ACL_UNKNOWN;
-}
-
-void ClSoftmax::create_internal_tensor(TensorInfo &info, InternalTensorIdx idx)
-{
- const auto tensor_idx = static_cast<uint32_t>(idx);
- if(!_internal_tensor[tensor_idx])
- {
- _internal_tensor[tensor_idx] = std::make_unique<CLTensor>();
- }
- _internal_tensor[tensor_idx]->allocator()->init(info);
-}
-
-void ClSoftmax::create_internal_tensor()
+ _max_info(),
+ _sum_info(),
+ _tmp_info(),
+ _permuted_src_info(),
+ _permuted_dst_info(),
+ _aux_mem(InternalTensorIdx::COUNT)
{
- for(uint32_t i = 0; i < static_cast<uint32_t>(InternalTensorIdx::COUNT); i++)
- {
- const auto tensor_idx = static_cast<InternalTensorIdx>(i);
-
- if(!_needs_permute && (tensor_idx == InternalTensorIdx::PERMUTED_DST || tensor_idx == InternalTensorIdx::PERMUTED_SRC))
- {
- continue;
- }
- create_internal_tensor(_internal_info[i], static_cast<InternalTensorIdx>(i));
- }
}
void ClSoftmax::configure(const CLCompileContext &compile_context, const ITensorInfo &src, ITensorInfo &dst, const SoftmaxKernelInfo &info)
@@ -137,6 +85,13 @@ void ClSoftmax::configure(const CLCompileContext &compile_context, const ITensor
const auto perm_info = softmax_helpers::get_permutation_vector_from_softmax_axis(actual_axis);
_permute_output->configure(compile_context, &_permuted_dst_info, &dst, perm_info);
}
+
+ _aux_mem[InternalTensorIdx::SUM] = MemoryInfo(offset_int_vec(InternalTensorIdx::SUM), MemoryLifetime::Temporary, _sum_info.total_size());
+ _aux_mem[InternalTensorIdx::TMP] = MemoryInfo(offset_int_vec(InternalTensorIdx::TMP), MemoryLifetime::Temporary, _tmp_info.total_size());
+ _aux_mem[InternalTensorIdx::MAX] = MemoryInfo(offset_int_vec(InternalTensorIdx::MAX), MemoryLifetime::Temporary, _max_info.total_size());
+
+ _aux_mem[InternalTensorIdx::PERMUTED_SRC] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), MemoryLifetime::Temporary, _permuted_src_info.total_size());
+ _aux_mem[InternalTensorIdx::PERMUTED_DST] = MemoryInfo(offset_int_vec(InternalTensorIdx::PERMUTED_DST), MemoryLifetime::Temporary, _permuted_dst_info.total_size());
}
Status ClSoftmax::validate(const ITensorInfo &src, const ITensorInfo &dst, const SoftmaxKernelInfo &info)
@@ -172,105 +127,60 @@ Status ClSoftmax::validate(const ITensorInfo &src, const ITensorInfo &dst, const
return Status{};
}
-void ClSoftmax::import_workspace_memory(ITensorPack &tensors)
+void ClSoftmax::run(ITensorPack &tensors)
{
- auto import_workspace_memory = [this, &tensors](InternalTensorIdx idx)
- {
- const auto workspace_idx = convert_internal_idx_to_tensor_type(idx);
- auto imported_tensor = tensors.get_tensor(workspace_idx);
- if(imported_tensor)
- {
- auto imported_memory = utils::cast::polymorphic_downcast<ICLTensor *>(imported_tensor)->cl_buffer();
- _internal_tensor[static_cast<uint32_t>(idx)].get()->allocator()->import_memory(imported_memory);
- }
- };
+ auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
+ auto dst = tensors.get_tensor(TensorType::ACL_DST);
- import_workspace_memory(InternalTensorIdx::PERMUTED_SRC);
- import_workspace_memory(InternalTensorIdx::PERMUTED_DST);
- import_workspace_memory(InternalTensorIdx::MAX);
- import_workspace_memory(InternalTensorIdx::SUM);
- import_workspace_memory(InternalTensorIdx::TMP);
-}
+ CLAuxTensorHandler sum(offset_int_vec(InternalTensorIdx::SUM), _sum_info, tensors, false);
+ CLAuxTensorHandler tmp(offset_int_vec(InternalTensorIdx::TMP), _tmp_info, tensors, false);
+ CLAuxTensorHandler max(offset_int_vec(InternalTensorIdx::MAX), _max_info, tensors, false);
+
+ CLAuxTensorHandler permuted_src(offset_int_vec(InternalTensorIdx::PERMUTED_SRC), _permuted_src_info, tensors, false);
+ CLAuxTensorHandler permuted_dst(offset_int_vec(InternalTensorIdx::PERMUTED_DST), _permuted_dst_info, tensors, false);
-void ClSoftmax::run_source_permute(const ITensor *src)
-{
if(_needs_permute)
{
- auto permuted_src = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_SRC)].get();
- run_permute(_permute_input.get(), src, permuted_src);
+ ITensorPack pack;
+ pack.add_const_tensor(TensorType::ACL_SRC, src);
+ pack.add_tensor(TensorType::ACL_DST, permuted_src.get());
+ _permute_input.get()->run(pack);
}
-}
-void ClSoftmax::run_destination_permute(ITensor *dst)
-{
+ ITensorPack sum_pack;
+ ITensorPack norm_pack;
if(_needs_permute)
{
- auto permuted_dst = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_DST)].get();
- run_permute(_permute_output.get(), permuted_dst, dst);
+ sum_pack.add_const_tensor(TensorType::ACL_SRC, permuted_src.get());
+ norm_pack.add_tensor(TensorType::ACL_DST, permuted_dst.get());
}
-}
-
-void ClSoftmax::run_max_sum(const ITensor *src)
-{
- auto max = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::MAX)].get();
- auto sum = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::SUM)].get();
- auto tmp = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::TMP)].get();
-
- ARM_COMPUTE_ERROR_ON_NULLPTR(src, tmp, max, sum);
+ else
+ {
+ sum_pack.add_const_tensor(TensorType::ACL_SRC, src);
+ norm_pack.add_tensor(TensorType::ACL_DST, dst);
+ }
+ sum_pack.add_tensor(TensorType::ACL_DST, tmp.get());
+ sum_pack.add_tensor(TensorType::ACL_INT_0, max.get());
+ sum_pack.add_tensor(TensorType::ACL_INT_1, sum.get());
- ITensorPack sum_pack;
- sum_pack.add_const_tensor(TensorType::ACL_SRC, src);
- sum_pack.add_tensor(TensorType::ACL_DST, tmp);
- sum_pack.add_tensor(TensorType::ACL_INT_0, max);
- sum_pack.add_tensor(TensorType::ACL_INT_1, sum);
+ norm_pack.add_const_tensor(TensorType::ACL_SRC, tmp.get());
+ norm_pack.add_tensor(TensorType::ACL_INT_0, sum.get());
CLScheduler::get().enqueue_op(*_max_shift_exp_sum_kernel.get(), sum_pack, false);
-}
-
-void ClSoftmax::run_norm(ITensor *dst)
-{
- auto sum = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::SUM)].get();
- auto tmp = _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::TMP)].get();
-
- ARM_COMPUTE_ERROR_ON_NULLPTR(tmp, sum, dst);
-
- ITensorPack norm_pack;
- norm_pack.add_const_tensor(TensorType::ACL_SRC, tmp);
- norm_pack.add_tensor(TensorType::ACL_DST, dst);
- norm_pack.add_tensor(TensorType::ACL_INT_0, sum);
-
CLScheduler::get().enqueue_op(*_norm_kernel.get(), norm_pack, false);
-}
-
-void ClSoftmax::run(ITensorPack &tensors)
-{
- create_internal_tensor();
-
- auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
- auto dst = tensors.get_tensor(TensorType::ACL_DST);
-
- import_workspace_memory(tensors);
- run_source_permute(src);
- run_max_sum(!_needs_permute ? src : _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_SRC)].get());
- run_norm(!_needs_permute ? dst : _internal_tensor[static_cast<uint32_t>(InternalTensorIdx::PERMUTED_DST)].get());
- run_destination_permute(dst);
-}
-
-experimental::MemoryRequirements ClSoftmax::workspace() const
-{
- experimental::MemoryRequirements req{};
-
- req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::SUM), _sum_info.total_size(), 0);
- req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::TMP), _tmp_info.total_size(), 0);
- req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::MAX), _max_info.total_size(), 0);
if(_needs_permute)
{
- req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::PERMUTED_SRC), _permuted_src_info.total_size(), 0);
- req.emplace_back(convert_internal_idx_to_tensor_type(InternalTensorIdx::PERMUTED_DST), _permuted_dst_info.total_size(), 0);
+ ITensorPack pack;
+ pack.add_const_tensor(TensorType::ACL_SRC, permuted_dst.get());
+ pack.add_tensor(TensorType::ACL_DST, dst);
+ _permute_output.get()->run(pack);
}
+}
- return req;
+experimental::MemoryRequirements ClSoftmax::workspace() const
+{
+ return _aux_mem;
}
} // namespace opencl
} // namespace arm_compute \ No newline at end of file
diff --git a/src/runtime/gpu/cl/operators/ClSoftmax.h b/src/runtime/gpu/cl/operators/ClSoftmax.h
index e38b7c595a..f19a51fc5e 100644
--- a/src/runtime/gpu/cl/operators/ClSoftmax.h
+++ b/src/runtime/gpu/cl/operators/ClSoftmax.h
@@ -67,7 +67,7 @@ public:
experimental::MemoryRequirements workspace() const override;
private:
- enum class InternalTensorIdx
+ enum InternalTensorIdx
{
MAX = 0,
SUM,
@@ -77,41 +77,19 @@ private:
COUNT
};
- /** Create a single internal tensor
- *
- * @param[in] info The information used to create a tensor
- * @param[in] idx The index within the internal array the created tensor will be held
- */
- void create_internal_tensor(TensorInfo &info, InternalTensorIdx idx);
- /** Create all required internal tensors */
- void create_internal_tensor();
- /** Function to convert from internal tensor index to @ref TensorType used externally */
- TensorType convert_internal_idx_to_tensor_type(InternalTensorIdx idx) const;
- /** Function to import workspace memory allocated by the caller into internal tensor instances */
- void import_workspace_memory(ITensorPack &tensors);
- /** Function to permute the given source tensor when permutation is required */
- void run_source_permute(const ITensor *src);
- /** Function to permute the intemediate tensor to the final destination tensor when permutation is required */
- void run_destination_permute(ITensor *dst);
- /** Function to run @ref arm_compute::opencl::kernels::ClLogits1DMaxShiftExpSumKernel */
- void run_max_sum(const ITensor *src);
- /** Function to run @ref kernels::ClLogits1DNormKernel */
- void run_norm(ITensor *dst);
-
std::unique_ptr<ClPermute> _permute_input;
std::unique_ptr<ClPermute> _permute_output;
std::unique_ptr<kernels::ClLogits1DMaxShiftExpSumKernel> _max_shift_exp_sum_kernel;
std::unique_ptr<kernels::ClLogits1DNormKernel> _norm_kernel;
bool _needs_permute{ false };
- std::array<TensorInfo, static_cast<uint32_t>(InternalTensorIdx::COUNT)> _internal_info{};
- std::array<std::unique_ptr<CLTensor>, static_cast<uint32_t>(InternalTensorIdx::COUNT)> _internal_tensor{};
+ TensorInfo _max_info;
+ TensorInfo _sum_info;
+ TensorInfo _tmp_info;
+ TensorInfo _permuted_src_info;
+ TensorInfo _permuted_dst_info;
- TensorInfo &_max_info;
- TensorInfo &_sum_info;
- TensorInfo &_tmp_info;
- TensorInfo &_permuted_src_info;
- TensorInfo &_permuted_dst_info;
+ experimental::MemoryRequirements _aux_mem{};
};
} // opencl