author     Michele Di Giorgio <michele.digiorgio@arm.com>  2021-06-16 11:14:41 +0100
committer  Michele Di Giorgio <michele.digiorgio@arm.com>  2021-06-25 13:52:38 +0000
commit     d7316eb877cc4ff8573219374335e917b19a0203 (patch)
tree       9918f85a12424ccd53ae91f4d7b7701b6e0747a9 /src/runtime/cpu
parent     cd060c47c1bad06f2aad8f0f8f94a72c4f75b919 (diff)
Port NEGEMMConv2d to memory injecting interface
Resolves: COMPMID-4506, COMPMID-4570

Change-Id: I6d37a06da141f1fcfcaa8525322a319cb0234791
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5824
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
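For readers new to the memory injecting interface: rather than owning a MemoryGroup, an IWeightsManager and Tensor members, the operator now only describes its auxiliary buffers through workspace(), and the caller allocates and injects those buffers through the ITensorPack. A minimal caller-side sketch using the types from this patch; allocate_aux() is a hypothetical helper standing in for whatever allocator the caller uses:

// Caller-side sketch (assumes `using namespace arm_compute;`).
// allocate_aux() is NOT a library function; it must return an ITensor*
// whose buffer honours req.size and req.alignment.
cpu::CpuGemmDirectConv2d conv;
conv.configure(&src_info, &weights_info, &biases_info, &dst_info, conv2d_info);

ITensorPack pack{ { ACL_SRC_0, &src }, { ACL_SRC_1, &weights }, { ACL_SRC_2, &biases }, { ACL_DST, &dst } };

// Inject one buffer per requirement published by the operator.
for(const auto &req : conv.workspace())
{
    // req.slot is offset_int_vec(AuxTensorIdx); req.lifetime says when the
    // buffer may be freed: after run() (Temporary), after prepare() (Prepare)
    // or only with the operator itself (Persistent).
    pack.add_tensor(req.slot, allocate_aux(req.size, req.alignment));
}

conv.prepare(pack); // permutes/pretransposes the weights into injected buffers
conv.run(pack);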
Diffstat (limited to 'src/runtime/cpu')
-rw-r--r--  src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp               |  77
-rw-r--r--  src/runtime/cpu/operators/CpuGemmDirectConv2d.h                 |  40
-rw-r--r--  src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp  | 265
-rw-r--r--  src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h    |  20
4 files changed, 145 insertions, 257 deletions
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
index e50099df1f..c2e9f24ff6 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.cpp
@@ -26,10 +26,10 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/runtime/cpu/operators/CpuActivation.h"
-#include "src/runtime/cpu/operators/CpuPermute.h"
-#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
+
+#include "support/Cast.h"
#include <set>
@@ -37,6 +37,9 @@ namespace arm_compute
{
namespace cpu
{
+using namespace arm_compute::experimental;
+using namespace arm_compute::utils::cast;
+
namespace
{
GEMMLowpOutputStageInfo calculate_output_stage_metadata(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act)
@@ -87,12 +90,14 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect
}
} // namespace
-CpuGemmDirectConv2d::CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager)
- : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>(memory_manager)),
+CpuGemmDirectConv2d::CpuGemmDirectConv2d()
+ : _gemm_asm_func(std::make_unique<CpuGemmAssemblyDispatch>()),
_activation_func(std::make_unique<CpuActivation>()),
_weights_permute_func(std::make_unique<CpuPermute>()),
- _permuted_weights_info(),
- _permuted_weights(std::make_unique<Tensor>())
+ _aux_mem(AuxTensorIdx::Count),
+ _perm_weights(),
+ _run_activation(false),
+ _is_prepared(false)
{
}
@@ -106,8 +111,10 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w
biases != nullptr ? biases : nullptr,
dst,
info));
- _original_weights_info = weights;
- _weights_permute_func->configure(weights, &_permuted_weights_info, PermutationVector{ 3, 0, 1, 2 });
+ _run_activation = info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info);
+ _is_prepared = false;
+
+ _weights_permute_func->configure(weights, &_perm_weights, PermutationVector{ 3, 0, 1, 2 });
// Configure assembly dispatch
cpu::AsmGemmInfo asm_info = init_assembly_metadata(info, false);
@@ -115,13 +122,27 @@ void CpuGemmDirectConv2d::configure(const ITensorInfo *src, const ITensorInfo *w
{
asm_info.output_stage = calculate_output_stage_metadata(src, weights, dst, info.act_info);
}
- _gemm_asm_func->configure(src, &_permuted_weights_info, biases, dst, asm_info);
+ _gemm_asm_func->configure(src, &_perm_weights, biases, dst, asm_info);
// Configure activation
- if(info.act_info.enabled() && !_gemm_asm_func->is_activation_supported(info.act_info))
+ if(_run_activation)
{
_activation_func->configure(dst, nullptr, info.act_info);
- _run_activation = true;
+ }
+
+ // Add auxiliary memory requirements of the assembly dispatch
+ auto asm_mem_req = _gemm_asm_func->workspace();
+ _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
+ _aux_mem[Pretranspose] = asm_mem_req[Pretranspose];
+
+ if(_aux_mem[Pretranspose].size > 0)
+ {
+ // Release permuted weights at the end of prepare as they are further transposed by the assembly dispatch
+ _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, weights->total_size());
+ }
+ else
+ {
+ _aux_mem[PermutedWeights] = MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Persistent, weights->total_size());
}
}
Status CpuGemmDirectConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const Conv2dInfo &info)
@@ -172,35 +193,29 @@ void CpuGemmDirectConv2d::run(ITensorPack &tensors)
}
}
-void CpuGemmDirectConv2d::allocate_permuted_weights()
-{
- // TODO: This function will be removed when memory injection is implemeted.
- ARM_COMPUTE_ERROR_ON(_permuted_weights == nullptr);
- _permuted_weights->allocator()->free();
- _permuted_weights->allocator()->init(_permuted_weights_info);
- _permuted_weights->allocator()->allocate();
-}
-
void CpuGemmDirectConv2d::prepare(ITensorPack &tensors)
{
if(!_is_prepared)
{
- allocate_permuted_weights();
- ITensorPack permute_tensors
- {
- { TensorType::ACL_SRC, tensors.get_const_tensor(TensorType::ACL_SRC_1) },
- { TensorType::ACL_DST, _permuted_weights.get() },
- };
+ const ITensor *weights = tensors.get_const_tensor(ACL_SRC_1);
+ ITensor *weights_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(PermutedWeights)));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(weights, weights_aux);
+ CpuAuxTensorHandler permuted_weights(_perm_weights, *weights_aux);
+ ITensorPack permute_tensors{ { ACL_SRC, weights }, { ACL_DST, permuted_weights.get() } };
_weights_permute_func->run(permute_tensors);
- tensors.get_const_tensor(TensorType::ACL_SRC_1)->mark_as_unused();
+ tensors.add_const_tensor(ACL_SRC_1, permuted_weights.get());
+ // Call prepare of assembly dispatch
+ _gemm_asm_func->prepare(tensors);
- // switch the original tensor with permuted tensor
- tensors.add_const_tensor(TensorType::ACL_SRC_1, _permuted_weights.get());
_is_prepared = true;
}
}
+experimental::MemoryRequirements CpuGemmDirectConv2d::workspace() const
+{
+ return _aux_mem;
+}
} // namespace cpu
} // namespace arm_compute
\ No newline at end of file
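
The Prepare/Persistent split configured above is the payoff of the port: when the assembly dispatch pretransposes the weights again, the permuted copy is only alive while prepare() runs and its buffer can be reclaimed afterwards. A sketch of how a caller-side pool might bucket the three lifetimes this patch uses (the grouping is an assumption for illustration, not library code):

// Hypothetical grouping of published requirements by lifetime.
// Requires <vector> and arm_compute/core/experimental/Types.h.
void group_by_lifetime(const experimental::MemoryRequirements &reqs,
                       std::vector<experimental::MemoryInfo> &run_pool,      // freed after run():     AsmGemmWorkspace
                       std::vector<experimental::MemoryInfo> &prepare_pool,  // freed after prepare(): PermutedWeights (when pretransposed again)
                       std::vector<experimental::MemoryInfo> &operator_pool) // operator lifetime:     Pretranspose
{
    for(const auto &req : reqs)
    {
        switch(req.lifetime)
        {
            case experimental::MemoryLifetime::Temporary:  run_pool.push_back(req);      break;
            case experimental::MemoryLifetime::Prepare:    prepare_pool.push_back(req);  break;
            case experimental::MemoryLifetime::Persistent: operator_pool.push_back(req); break;
        }
    }
}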
diff --git a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
index 6aa17c2349..b572f36a3a 100644
--- a/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
+++ b/src/runtime/cpu/operators/CpuGemmDirectConv2d.h
@@ -24,14 +24,12 @@
#ifndef ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H
#define ARM_COMPUTE_CPU_GEMM_DIRECT_CONV_2D_H
-#include "arm_compute/core/ITensorInfo.h"
-#include "arm_compute/core/experimental/Types.h"
-#include "arm_compute/runtime/Tensor.h"
+#include "arm_compute/core/TensorInfo.h"
#include "src/core/common/Macros.h"
-#include "src/core/cpu/ICpuKernel.h"
#include "src/runtime/cpu/ICpuOperator.h"
-
-#include <memory>
+#include "src/runtime/cpu/operators/CpuActivation.h"
+#include "src/runtime/cpu/operators/CpuPermute.h"
+#include "src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
namespace arm_compute
{
@@ -40,15 +38,11 @@ class ITensor;
struct Conv2dInfo;
namespace cpu
{
-class CpuGemmAssemblyDispatch;
-class CpuActivation;
-class CpuPermute;
-
class CpuGemmDirectConv2d : public ICpuOperator
{
public:
/** Constructor */
- CpuGemmDirectConv2d(const std::shared_ptr<IMemoryManager> &memory_manager = nullptr);
+ CpuGemmDirectConv2d();
ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmDirectConv2d);
/** Destructor */
~CpuGemmDirectConv2d();
@@ -89,22 +83,24 @@ public:
// Inherited methods overridden:
void run(ITensorPack &tensors) override;
void prepare(ITensorPack &constants) override;
+ experimental::MemoryRequirements workspace() const override;
private:
+ enum AuxTensorIdx
+ {
+ AsmGemmWorkspace = 0,
+ Pretranspose,
+ PermutedWeights,
+ Count
+ };
+
std::unique_ptr<CpuGemmAssemblyDispatch> _gemm_asm_func;
std::unique_ptr<CpuActivation> _activation_func;
std::unique_ptr<CpuPermute> _weights_permute_func;
- const ITensorInfo *_original_weights_info{};
- TensorInfo _permuted_weights_info;
- std::unique_ptr<Tensor> _permuted_weights{ nullptr };
- bool _is_prepared{ false };
- bool _run_activation{ false };
-
- /** Function to allocated a tensor for permuted weights
- *
- * @note This function will be removed when memory injection is properly implemented.
- */
- void allocate_permuted_weights();
+ experimental::MemoryRequirements _aux_mem;
+ TensorInfo _perm_weights;
+ bool _run_activation;
+ bool _is_prepared;
};
} // namespace cpu
} // namespace arm_compute
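
The AuxTensorIdx values above are operator-local indices; offset_int_vec() from src/core/helpers/MemoryHelpers.h shifts them into the slot range reserved for internal tensors so that injected buffers cannot collide with the ACL_SRC_*/ACL_DST ids sharing the same ITensorPack. The gist of it (see the header for the real definition):

// Roughly what offset_int_vec() does: bias the operator-internal index by the
// TensorType base reserved for internal vectors (ACL_INT_VEC), keeping aux
// slots disjoint from ACL_SRC_0..ACL_SRC_END and ACL_DST in the same pack.
inline int offset_int_vec(int idx)
{
    return ACL_INT_VEC + idx;
}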
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 1101e05a0d..79ea1cb5a7 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -27,15 +27,18 @@
#include "src/core/CPP/Validate.h"
#include "src/core/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"
#include "src/core/cpu/kernels/assembly/arm_gemm.hpp"
+#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/utils/AssemblyUtils.h"
+#include "src/runtime/cpu/utils/CpuAuxTensorHandler.h"
#include <arm_neon.h>
-#include <cstdlib>
namespace arm_compute
{
namespace cpu
{
+using namespace arm_compute::experimental;
+
namespace
{
struct free_delete
@@ -113,103 +116,27 @@ IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataTyp
return scheduling_hint;
}
-template <typename TypeInput, typename TypeOutput>
-class FallbackTransform : public ITransformWeights
-{
-public:
- FallbackTransform() noexcept {};
- /** Prevent instances of this class from being copied (As this class contains pointers) */
- FallbackTransform(const FallbackTransform &) = delete;
- /** Default move constructor */
- FallbackTransform(FallbackTransform &&) = default;
- /** Prevent instances of this class from being copied (As this class contains pointers) */
- FallbackTransform &operator=(const FallbackTransform &) = delete;
- /** Default move assignment operator */
- FallbackTransform &operator=(FallbackTransform &&) = default;
- void run() override
- {
- _output.allocator()->allocate();
- ARM_COMPUTE_ERROR_ON(_output.buffer() == nullptr);
- _gemm_kernel_asm->pretranspose_B_array(_output.buffer(), _in1_ptr, _ldb, _multi_stride_b);
- _reshape_run = true;
- }
-
- void release() override
- {
- _output.allocator()->free();
- }
-
- ITensor *get_weights() override
- {
- return &_output;
- }
-
- uint32_t uid() override
- {
- uint32_t id = (_B_pretranspose_size | 0x80000000);
- return id;
- }
-
- void configure(size_t B_pretranspose_size, unsigned int alignment)
- {
- _output.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment) }, 1, DataType::S8), alignment);
- _B_pretranspose_size = B_pretranspose_size;
- }
-
- void set_pretranspose(ITensor *tensor)
- {
- if(!_reshape_run)
- {
- _gemm_kernel_asm->set_pretransposed_B_data(tensor->buffer());
- }
- }
-
- void set_args(const int ldb, const TypeInput *in1_ptr, const int multi_stride_b, std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> gemm_kernel_asm)
- {
- _ldb = ldb;
- _in1_ptr = in1_ptr;
- _multi_stride_b = multi_stride_b;
- _gemm_kernel_asm = gemm_kernel_asm;
- }
-
-private:
- Tensor _output{};
- int _ldb{};
- const TypeInput *_in1_ptr{};
- int _multi_stride_b{};
- size_t _B_pretranspose_size{};
- std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
-};
-
/** Fallback in case ACL doesn't have a function */
template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
class Fallback : public CpuGemmAssemblyDispatch::IFallback
{
public:
/** Destructor */
- ~Fallback()
- {
- if(_pretranspose && !(is_weight_managed()))
- {
- delete _pretranspose;
- }
- }
+ ~Fallback() = default;
/** Initialise the functions's input and output.
*
- * @param[in] a Input tensor containing the Matrix A.
- * @param[in] b Input tensor containing the Matrix B.
- * @param[in] c Input tensor containing the Matrix C.
- * @param[out] d Output tensor to store the result of matrix multiplication.
- * @param[in] args Matrix multiplication information.
- * @param[in] gemm_info GEMM meta-data
- * @param[in] memory_group Memory group to be used by the function.
- * @param[in] weights_manager Weights manager to be used by the function.
- * @param[in] os Output stage meta-data.
+ * @param[in] a Input tensor containing the Matrix A.
+ * @param[in] b Input tensor containing the Matrix B.
+ * @param[in] c Input tensor containing the Matrix C.
+ * @param[out] d Output tensor to store the result of matrix multiplication.
+ * @param[in] args Matrix multiplication information.
+ * @param[in] gemm_info GEMM meta-data
+ * @param[in] os Output stage meta-data.
*/
void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
- MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
+ const OutputStage &os = {});
/** Set requantization shifts to be used
*
@@ -231,16 +158,17 @@ public:
// Inherited methods overridden:
void run(ITensorPack &tensors) override;
void prepare(ITensorPack &tensors) override;
- bool is_configured() const override;
+ bool is_configured() const override;
+ experimental::MemoryRequirements workspace() const override;
private:
- /** Allocate a workspace tensor.
- *
- * @param[in] workspace_size Size to allocate.
- * @param[in] memory_group Tensor memory group.
- * @param[in] alignment Workspace memory alignment.
- */
- void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
+ enum AuxTensorIdx
+ {
+ AsmGemmWorkspace = 0,
+ Pretranspose,
+ Count
+ };
+
/** Configure the indirect buffer
*
* @param[in] a Input tensor containing the Matrix A.
@@ -256,18 +184,14 @@ private:
std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
/** Optimised Arm® Neon™ kernel */
std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
- /** GEMM workspace */
- Tensor _workspace{};
- /** Pre-transpose tensor */
- ITensor *_pretranspose{ nullptr };
+ /** Assembly GEMM workspace tensor info */
+ TensorInfo _workspace_info{};
+ /** Pre-transpose tensor info */
+ TensorInfo _pretranspose_info{};
/** Prepared flag */
bool _is_prepared{ false };
/** GEMM meta-data */
AsmGemmInfo _gemm_info{};
- /** Weights manager */
- IWeightsManager *_weights_manager{ nullptr };
- /** Weights transform object */
- FallbackTransform<TypeInput, TypeOutput> _weights_transform{};
/** GEMM kernel description */
arm_gemm::KernelDescription _kernel_info{};
/** Per channel quantization shifts */
@@ -279,27 +203,9 @@ private:
/** Indirect buffer */
std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
-
- bool is_weight_managed()
- {
- // TODO (COMPMID-4539): This function should do the following:
- // _weights_manager && _weights_manager->are_weights_managed(_b)
- // , where _b is the second Tensor that is used to be given to the configure().
- // Currently, however, weight manager is disabled to make this class stateless.
- // This should be revisited in the future.
- return false;
- }
-
- void acquire_managed_weight()
- {
- // TODO (COMPMID-4539): This function should do the following:
- // _pretranspose = _weights_manager->acquire(_b, &_weights_transform);
- // , where _b is the second Tensor that is used to be given to the configure().
- // Currently, however, weight manager is disabled to make this class stateless.
- _pretranspose = nullptr;
- }
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{ Count };
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -439,12 +345,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
- MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
+ const OutputStage &os)
{
ARM_COMPUTE_UNUSED(c);
arm_gemm::GemmConfig gemm_cfg;
- _kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
- _weights_manager = weights_manager;
+ _kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
if(_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
{
gemm_cfg.filter = _kernel_info.name;
@@ -461,13 +366,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
- const size_t workspace_size = _gemm_kernel_asm->get_working_size();
- if(workspace_size > 0)
- {
- // Allocate workspace
- const unsigned int alignment = 4096;
- allocate_workspace(workspace_size, memory_group, alignment);
- }
+ const size_t workspace_size = _gemm_kernel_asm->get_working_size();
+ const unsigned int alignment = 4096;
+ _workspace_info = TensorInfo(TensorShape(workspace_size), 1, DataType::U8);
+ _aux_mem[AsmGemmWorkspace] = MemoryInfo(offset_int_vec(AsmGemmWorkspace), MemoryLifetime::Temporary, workspace_size, alignment);
//if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
//the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
@@ -487,16 +389,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
// Forcing 128-byte alignment (required by 32-bit kernels)
const unsigned int alignment = 128;
const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
- if(is_weight_managed())
- {
- _weights_transform.configure(B_pretranspose_size, alignment);
- acquire_managed_weight();
- }
- else
- {
- _pretranspose = new Tensor();
- static_cast<Tensor *>(_pretranspose)->allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment) }, 1, DataType::S8), alignment);
- }
+ _pretranspose_info = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
+ _aux_mem[Pretranspose] = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
}
// Handle indirect GEMM convolution
@@ -509,10 +403,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
- auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
- auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
if(!_is_prepared)
{
+ auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+ auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+
// Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
if(c && c->info()->data_type() == DataType::S32)
{
@@ -526,24 +421,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- if(is_weight_managed())
- {
- _weights_transform.set_args(ldb, in1_ptr, multi_stride_b, _gemm_kernel_asm);
- _weights_manager->run(b, &_weights_transform);
-
- // If we didn't run the reshape function, set the pretransposed buffer
- if(!_weights_transform.is_reshape_run())
- {
- _weights_transform.set_pretranspose(_pretranspose);
- }
- }
- else
- {
- static_cast<Tensor *>(_pretranspose)->allocator()->allocate();
- ARM_COMPUTE_ERROR_ON(_pretranspose->buffer() == nullptr);
- _gemm_kernel_asm->pretranspose_B_array(_pretranspose->buffer(), in1_ptr, ldb, multi_stride_b);
- b->mark_as_unused();
- }
+ CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
+ ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
+ _gemm_kernel_asm->pretranspose_B_array(pretranspose.get()->buffer(), in1_ptr, ldb, multi_stride_b);
}
if(_gemm_info.method == AsmConvMethod::Indirect)
@@ -556,18 +436,15 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment)
+bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
{
- ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "size cannot be 0");
- _workspace.allocator()->init(TensorInfo(TensorShape{ (workspace_size + alignment) }, 1, DataType::S8), alignment);
- memory_group.manage(&_workspace);
- _workspace.allocator()->allocate();
+ return _optimised_kernel != nullptr;
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
-bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
+experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const
{
- return _optimised_kernel != nullptr;
+ return _aux_mem;
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -609,9 +486,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());
// Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
- if(_workspace.buffer() != nullptr)
+ CpuAuxTensorHandler workspace(offset_int_vec(AsmGemmWorkspace), _workspace_info, tensors, false);
+ if(workspace.get()->buffer() != nullptr)
{
- _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace.buffer()));
+ _gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(workspace.get()->buffer()));
const unsigned int split_dim = scheduling_hint.split_dimension();
const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
unsigned int num_threads = NEScheduler::get().num_threads();
@@ -656,9 +534,9 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
}
template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
- const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
- IWeightsManager *weights_manager)
+void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
+ const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
+ arm_gemm::Activation activation, const AsmGemmInfo &info)
{
Params p = extract_parameters(a, b, d, info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
@@ -668,14 +546,14 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
- fallback->configure(a, b, c, d, args, info, memory_group, weights_manager);
+ fallback->configure(a, b, c, d, args, info);
arm_gemm = std::move(fallback);
}
template <typename TypeInput, typename TypeOutput>
-void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
- const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
- IWeightsManager *weights_manager)
+void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
+ const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
+ arm_gemm::Activation activation, const AsmGemmInfo &info)
{
ARM_COMPUTE_UNUSED(activation);
Params p = extract_parameters(a, b, d, info);
@@ -713,14 +591,13 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
}
// Configure fallback
- fallback->configure(a, b, c, d, args, info, memory_group, weights_manager, gemm_requant_info);
+ fallback->configure(a, b, c, d, args, info, gemm_requant_info);
arm_gemm = std::move(fallback);
}
-
} //namespace
-CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
- : _arm_gemm(nullptr), _memory_group(std::move(memory_manager)), _weights_manager(weights_manager)
+CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
+ : _arm_gemm(nullptr)
{
}
@@ -775,40 +652,40 @@ void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo
switch(a->data_type())
{
case DataType::F32:
- create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
if(d->data_type() == DataType::S32)
{
- create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
}
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
if(d->data_type() == DataType::S32)
{
- create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
}
break;
#endif /* __aarch64__ */
#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
case DataType::BFLOAT16:
- create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
break;
#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
+ create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default:
@@ -829,10 +706,14 @@ bool CpuGemmAssemblyDispatch::is_configured() const
void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
{
- MemoryGroupResourceScope scope_mg(_memory_group);
-
ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
_arm_gemm->run(tensors);
}
+
+experimental::MemoryRequirements CpuGemmAssemblyDispatch::workspace() const
+{
+ ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
+ return _arm_gemm->workspace();
+}
} // namespace cpu
} // namespace arm_compute
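
CpuAuxTensorHandler, used in prepare() and run() above, is what binds a caller-injected buffer to the operator's stateless TensorInfo. A stripped-down illustration of the idea (a simplified re-implementation for exposition; the real class lives in src/runtime/cpu/utils/CpuAuxTensorHandler.h):

#include "arm_compute/runtime/Tensor.h"

// Presents a caller-owned buffer from the pack through operator-side metadata
// so kernels can consume it as an ITensor; no allocation, no ownership taken.
class AuxTensorView
{
public:
    AuxTensorView(int slot_id, TensorInfo &info, ITensorPack &pack)
    {
        if(ITensor *injected = pack.get_tensor(slot_id))
        {
            _tensor.allocator()->init(info);                        // operator's shape/type view
            _tensor.allocator()->import_memory(injected->buffer()); // borrow the caller's storage
        }
    }
    ~AuxTensorView()
    {
        _tensor.allocator()->free(); // drops the import; the caller still owns the memory
    }
    ITensor *get()
    {
        return &_tensor; // buffer() stays nullptr if nothing was injected
    }

private:
    Tensor _tensor{};
};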
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
index ffc097c75c..355273adeb 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.h
@@ -24,10 +24,6 @@
#ifndef ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
#define ARM_COMPUTE_CPU_INTERNAL_CPU_GEMM_ASSEMBLY_DISPATCH_H
-#include "arm_compute/runtime/IMemoryManager.h"
-#include "arm_compute/runtime/IWeightsManager.h"
-#include "arm_compute/runtime/MemoryGroup.h"
-#include "arm_compute/runtime/Tensor.h"
#include "src/core/common/Macros.h"
#include "src/runtime/cpu/ICpuOperator.h"
@@ -62,7 +58,7 @@ class CpuGemmAssemblyDispatch : public ICpuOperator
{
public:
/** Constructor */
- CpuGemmAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
+ CpuGemmAssemblyDispatch();
/** Defautl destructor */
~CpuGemmAssemblyDispatch() = default;
@@ -71,10 +67,11 @@ public:
class IFallback
{
public:
- virtual void run(ITensorPack &tensors) = 0;
- virtual void prepare(ITensorPack &tensors) = 0;
- virtual bool is_configured() const = 0;
- virtual ~IFallback() = default;
+ virtual void run(ITensorPack &tensors) = 0;
+ virtual void prepare(ITensorPack &tensors) = 0;
+ virtual experimental::MemoryRequirements workspace() const = 0;
+ virtual bool is_configured() const = 0;
+ virtual ~IFallback() = default;
};
public:
@@ -115,11 +112,10 @@ public:
// Inherited methods overridden:
void prepare(ITensorPack &tensors) override;
void run(ITensorPack &tensors) override;
+ experimental::MemoryRequirements workspace() const override;
private:
- std::unique_ptr<IFallback> _arm_gemm; /**< Interface for the arm_gemm fallback */
- MemoryGroup _memory_group; /**< Function memory group */
- IWeightsManager *_weights_manager; /**< Pointer to the weights manager */
+ std::unique_ptr<IFallback> _arm_gemm; /**< Interface for the arm_gemm fallback */
};
} // namespace cpu
} // namespace arm_compute