author    Michalis Spyrou <michalis.spyrou@arm.com> 2019-09-10 17:20:34 +0100
committer Michalis Spyrou <michalis.spyrou@arm.com> 2019-09-26 10:17:30 +0000
commit    1a569a30a2f456ff1a3e0a665201e1c3ab92df80 (patch)
tree      9d68934f461579edefbe65246f6ee435aaa18808 /src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
parent    f1cf394ae882e6e8fb2e0986f88d2548b82a85bb (diff)
COMPMID-2161 [NEON] Create IWeightsManager class
Change-Id: I1a9a46da2f98e896b825099151b56d1d8271dd31
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1915
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
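
For context, the weights-manager pattern wired in below lets several functions share a single reshaped copy of constant weights instead of each keeping its own. A minimal sketch of the interaction, under simplifying assumptions: TensorStub, ITransformWeightsSketch and WeightsManagerSketch are hypothetical stand-ins, and only the method names (run, release, get_weights, uid, is_reshape_run, are_weights_managed, acquire) mirror what the diff uses.

#include <cstdint>
#include <map>

// Hypothetical stand-in; the real type is arm_compute::ITensor.
struct TensorStub
{
};

// Mirrors the ITransformWeights contract that FallbackTransform in this diff
// implements: run() produces the reshaped weights once, uid() identifies the
// transform, get_weights() exposes the destination tensor.
struct ITransformWeightsSketch
{
    virtual ~ITransformWeightsSketch() = default;
    virtual void        run()         = 0;
    virtual void        release()     = 0;
    virtual TensorStub *get_weights() = 0;
    virtual uint32_t    uid()         = 0;
    bool is_reshape_run() const
    {
        return _reshape_run;
    }

protected:
    bool _reshape_run{ false };
};

// Hypothetical manager: one cached transform per source-weights tensor, so
// several functions reusing the same constant weights share one reshape.
class WeightsManagerSketch
{
public:
    bool are_weights_managed(const TensorStub *weights) const
    {
        return _managed.count(weights) != 0;
    }

    // acquire() hands back the tensor that will hold the reshaped weights.
    TensorStub *acquire(const TensorStub *weights, ITransformWeightsSketch *transform)
    {
        _managed[weights] = transform;
        return transform->get_weights();
    }

    // run() executes the transform only the first time it is requested;
    // later callers get the already-reshaped buffer.
    void run(const TensorStub *weights, ITransformWeightsSketch *transform)
    {
        if(!transform->is_reshape_run())
        {
            transform->run();
        }
        _managed[weights] = transform;
    }

private:
    std::map<const TensorStub *, ITransformWeightsSketch *> _managed{};
};

On first use the manager runs the transform; subsequent functions configured with the same weights tensor receive the cached result from acquire() without re-running the pretranspose.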
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r--  src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp | 163
1 file changed, 131 insertions(+), 32 deletions(-)
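
One subtlety the hunks below encode: _pretranspose changes from a member Tensor to a raw ITensor pointer that is either acquired from the weights manager or new'd locally, so Fallback's new destructor must free it only in the unmanaged case. A minimal sketch of that ownership rule, with hypothetical names (OwnedBuffer, OwnershipSketch are illustrative, not library types):

// Sketch of the ownership rule: the pretranspose buffer is freed here only
// when this object allocated it itself; manager-acquired buffers stay owned
// by the weights manager.
struct OwnedBuffer
{
    virtual ~OwnedBuffer() = default;
};

struct OwnershipSketch
{
    OwnedBuffer *_pretranspose{ nullptr }; // self-owned or manager-owned
    bool         _weights_managed{ false };

    ~OwnershipSketch()
    {
        if(_pretranspose != nullptr && !_weights_managed)
        {
            delete _pretranspose; // new'd locally when weights are unmanaged
        }
        // when managed, the weights manager releases the buffer instead
    }
};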
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 2a4498b0a9..956ded55d2 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -38,7 +38,8 @@ namespace
std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info,
const ITensor *a, const ITensor *b, ITensor *d,
float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager)
+ std::shared_ptr<IMemoryManager> memory_manager,
+ IWeightsManager *weights_manager)
{
// Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
@@ -50,7 +51,7 @@ std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescr
{
return nullptr;
}
- auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager, weights_manager);
function->configure(a, b, d, alpha, beta, gemm_info);
return std::move(function);
}
@@ -73,25 +74,95 @@ std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescr
}
}
+template <typename TypeInput, typename TypeOutput>
+class FallbackTransform : public ITransformWeights
+{
+public:
+ void run() override
+ {
+ _output.allocator()->allocate();
+ ARM_COMPUTE_ERROR_ON(_output.buffer() == nullptr);
+ _gemm_kernel_asm->pretranspose_B_array(_output.buffer(), _in1_ptr, _ldb, _multi_stride_b);
+ _reshape_run = true;
+ }
+
+ void release() override
+ {
+ _output.allocator()->free();
+ }
+
+ ITensor *get_weights() override
+ {
+ return &_output;
+ }
+
+ uint32_t uid() override
+ {
+ uint32_t id = (_B_pretranspose_size | 0x80000000);
+ return id;
+ }
+
+ void configure(size_t B_pretranspose_size, unsigned int alignment)
+ {
+ _output.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+ _B_pretranspose_size = B_pretranspose_size;
+ }
+
+ void set_pretranspose(ITensor *tensor)
+ {
+ if(!_reshape_run)
+ {
+ _gemm_kernel_asm->set_pretransposed_B_data(tensor->buffer());
+ }
+ }
+
+ void set_args(const int ldb, const TypeInput *in1_ptr, const int multi_stride_b, std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> gemm_kernel_asm)
+ {
+ _ldb = ldb;
+ _in1_ptr = in1_ptr;
+ _multi_stride_b = multi_stride_b;
+ _gemm_kernel_asm = gemm_kernel_asm;
+ }
+
+private:
+ Tensor _output{};
+ int _ldb{};
+ const TypeInput *_in1_ptr{};
+ int _multi_stride_b{};
+ size_t _B_pretranspose_size{};
+ std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+};
+
/** Fallback in case ACL doesn't have a function */
template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
class Fallback : public NEGEMMAssemblyDispatch::IFallback
{
public:
+ /** Destructor */
+ ~Fallback()
+ {
+ // Release memory if we have allocated the memory ourselves
+ if(_pretranspose && !(_weights_manager && _weights_manager->are_weights_managed(_b)))
+ {
+ delete _pretranspose;
+ }
+ }
+
/** Initialise the function's input and output.
*
- * @param[in] a Input tensor containing the Matrix A.
- * @param[in] b Input tensor containing the Matrix B.
- * @param[in] c Input tensor containing the Matrix C.
- * @param[out] d Output tensor to store the result of matrix multiplication.
- * @param[in] args Matrix multiplication information.
- * @param[in] gemm_info GEMM meta-data
- * @param[in] memory_group Memory group to be used by the function.
- * @param[in] os Output stage meta-data.
+ * @param[in] a Input tensor containing the Matrix A.
+ * @param[in] b Input tensor containing the Matrix B.
+ * @param[in] c Input tensor containing the Matrix C.
+ * @param[out] d Output tensor to store the result of matrix multiplication.
+ * @param[in] args Matrix multiplication information.
+ * @param[in] gemm_info GEMM meta-data
+ * @param[in] memory_group Memory group to be used by the function.
+ * @param[in] weights_manager Weights manager to be used by the function.
+ * @param[in] os Output stage meta-data.
*/
void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info,
- MemoryGroup &memory_group, const OutputStage &os = {});
+ MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
// Inherited methods overridden:
void run() override;
@@ -108,7 +179,7 @@ private:
void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
/** Assembly Gemm kernel */
- std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+ std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
/** Optimised NEON kernel */
std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
/** Input A */
@@ -130,20 +201,25 @@ private:
/** GEMM workspace */
Tensor _workspace{};
/** Pre-transpose tensor */
- Tensor _pretranspose{};
+ ITensor *_pretranspose{ nullptr };
/** Prepared flag */
bool _is_prepared{ false };
/** GEMM meta-data */
GEMMInfo _gemm_info{};
+ /** Weights manager */
+ IWeightsManager *_weights_manager{ nullptr };
+ /** Weights transform object */
+ FallbackTransform<TypeInput, TypeOutput> _weights_transform{};
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info,
- MemoryGroup &memory_group, const OutputStage &os)
+ MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
{
arm_gemm::GemmConfig gemm_cfg;
const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
+ _weights_manager = weights_manager;
if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
{
gemm_cfg.filter = gemm_kernel_info.name;
@@ -190,7 +266,16 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c
// Forcing 128-byte alignment (required by 32-bit kernels)
const unsigned int alignment = 128;
const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
- _pretranspose.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+ if(weights_manager && _weights_manager->are_weights_managed(b))
+ {
+ _weights_transform.configure(B_pretranspose_size, alignment);
+ _pretranspose = _weights_manager->acquire(b, &_weights_transform);
+ }
+ else
+ {
+ _pretranspose = new Tensor();
+ static_cast<Tensor *>(_pretranspose)->allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+ }
}
}
@@ -208,14 +293,28 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare()
// Pretranspose B if required
if(_gemm_kernel_asm->B_pretranspose_required())
{
- _pretranspose.allocator()->allocate();
- ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
- _b->mark_as_unused();
+ if(_weights_manager && _weights_manager->are_weights_managed(_b))
+ {
+ _weights_transform.set_args(ldb, in1_ptr, multi_stride_b, _gemm_kernel_asm);
+ _weights_manager->run(_b, &_weights_transform);
+
+ // If we didn't run the reshape function, set the pretransposed buffer
+ if(!_weights_transform.is_reshape_run())
+ {
+ _weights_transform.set_pretranspose(_pretranspose);
+ }
+ }
+ else
+ {
+ static_cast<Tensor *>(_pretranspose)->allocator()->allocate();
+ ARM_COMPUTE_ERROR_ON(_pretranspose->buffer() == nullptr);
+ _gemm_kernel_asm->pretranspose_B_array(_pretranspose->buffer(), in1_ptr, ldb, multi_stride_b);
+ _b->mark_as_unused();
+ }
}
_is_prepared = true;
@@ -294,7 +393,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
template <typename TypeInput, typename TypeOutput>
void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager)
+ std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
{
INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
@@ -304,14 +403,14 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::
// Try to create an ACL function:
const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
- acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
+ acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
// If we still don't have an ACL function:
if(acl_function == nullptr)
{
//Fallback onto arm_gemm function if ACL doesn't support this method.
auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
- fallback->configure(a, b, c, d, args, gemm_info, memory_group);
+ fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager);
arm_gemm = std::move(fallback);
}
}
@@ -319,7 +418,7 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::
template <typename TypeInput, typename TypeOutput>
void create_function_or_arm_gemm_quant(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager)
+ std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
{
INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
@@ -339,22 +438,22 @@ void create_function_or_arm_gemm_quant(std::unique_ptr<IFunction> &acl_function,
// Try to create an ACL function:
const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args, gemm_requant_info);
- acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
+ acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
// If we still don't have an ACL function:
if(acl_function == nullptr)
{
// Fallback onto arm_gemm function if ACL doesn't support this method.
auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>();
- fallback->configure(a, b, c, d, args, gemm_info, memory_group, gemm_requant_info);
+ fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
arm_gemm = std::move(fallback);
}
}
} //namespace
-NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
- : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager)
+NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+ : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager), _weights_manager(weights_manager)
{
}
@@ -390,27 +489,27 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const
switch(a->info()->data_type())
{
case DataType::F32:
- create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+ create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
if(d->info()->data_type() == DataType::S32)
{
- create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+ create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
}
else
{
- create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+ create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
}
break;
case DataType::S8:
- create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+ create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+ create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default: