Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp | 104
1 file changed, 92 insertions(+), 12 deletions(-)
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 343ef21c0b..82bd465c99 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -31,6 +31,7 @@
#include "src/core/utils/AssemblyUtils.h"
#include "src/cpu/kernels/assembly/arm_gemm.hpp"
#include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"
+#include "src/cpu/operators/CpuTranspose.h"
#include "src/cpu/utils/CpuAuxTensorHandler.h"
#include <arm_neon.h>
@@ -229,6 +230,7 @@ private:
enum AuxTensorIdx
{
AsmGemmWorkspace = 0,
+ PrePretransposedB, /* Transposed B (rhs) before being passed to gemm or pretranspose_B_array */
Pretranspose,
Count
};
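For orientation, these enum values act as slot indices into the operator's auxiliary workspace; the new PrePretransposedB slot holds the transposed copy of B produced before packing. A minimal standalone sketch of such a slot table (hypothetical names, not the ACL memory infrastructure):

#include <array>
#include <cstddef>
#include <vector>

enum AuxSlot : std::size_t
{
    WorkspaceSlot = 0,
    PrePretransposedBSlot, // transposed B, produced before (or instead of) packing
    PretransposeSlot,      // packed (pretransposed) B
    SlotCount
};

// One scratch buffer per slot; a slot stays empty when its transformation is skipped.
struct AuxWorkspace
{
    std::array<std::vector<unsigned char>, SlotCount> buffers{};

    void reserve(AuxSlot slot, std::size_t bytes) { buffers[slot].resize(bytes); }
    unsigned char *data(AuxSlot slot) { return buffers[slot].empty() ? nullptr : buffers[slot].data(); }
};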
@@ -244,12 +246,16 @@ private:
/** Prepare the indirect buffer */
void prepare_indirect_buffer(ITensorPack &tensors);
+ /** Operator to transpose B before gemm or pretranspose_B_array */
+ std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr};
/** Assembly Gemm kernel */
std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr};
/** Optimised Arm® Neon™ kernel */
std::unique_ptr<INEKernel> _optimised_kernel{nullptr};
/** Assembly GEMM workspace tensor info */
TensorInfo _workspace_info{};
+ /** Pre-pre-transposed B tensor info */
+ TensorInfo _pre_pretransposed_b_info{};
/** Pre-transpose tensor info */
TensorInfo _pretranspose_info{};
/** Prepared flag */
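Conceptually, the new _pre_pretranspose_b operator swaps the two innermost dimensions of B before it reaches either the assembly GEMM or pretranspose_B_array. As a standalone illustration of that data movement (a plain row-major transpose, not the CpuTranspose kernel itself):

#include <cstddef>
#include <vector>

// Transpose a rows x cols row-major matrix into a cols x rows row-major matrix.
template <typename T>
std::vector<T> transpose_2d(const std::vector<T> &src, std::size_t rows, std::size_t cols)
{
    std::vector<T> dst(src.size());
    for (std::size_t r = 0; r < rows; ++r)
    {
        for (std::size_t c = 0; c < cols; ++c)
        {
            dst[c * rows + r] = src[r * cols + c];
        }
    }
    return dst;
}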
@@ -473,9 +479,45 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
_optimised_kernel = std::move(acl_gemm_wrapper);
_gemm_info = gemm_info;
+ // Check if we need to pre-pretranspose B. Fixed format kernels need no pre-pretranspose.
+ const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
+ if (run_pre_pretranspose_b)
+ {
+ _pre_pretranspose_b = std::make_unique<CpuTranspose>();
+ _pre_pretranspose_b->configure(b, &_pre_pretransposed_b_info);
+ MemoryLifetime lifetime;
+ if (_is_b_constant)
+ {
+ if (_gemm_kernel_asm->B_pretranspose_required())
+ {
+ // PrePretransposedB tensor is only used in prepare(), but is then succeeded by Pretranspose
+ // So PrePretransposedB can be freed inside prepare()
+ lifetime = MemoryLifetime::Prepare;
+ }
+ else
+ {
+ // PrePretransposedB tensor is only used in prepare(), but is the final transformation of B
+ // So PrePretransposedB needs to persist beyond prepare()
+ lifetime = MemoryLifetime::Persistent;
+ }
+ }
+ else
+ {
+ // PrePretransposedB tensor is always used in run() and doesn't need to persist
+ lifetime = MemoryLifetime::Temporary;
+ }
+ // Forcing 128-byte alignment (required by 32-bit kernels)
+ const unsigned int alignment = 128;
+ _aux_mem[PrePretransposedB] =
+ MemoryInfo(offset_int_vec(PrePretransposedB), lifetime, _pre_pretransposed_b_info.total_size(), alignment);
+ }
+
// Check for pre-transposed support
if (_gemm_kernel_asm->B_pretranspose_required())
{
+ // Fixed format kernels need no pretranspose.
+ ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
+ assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
// Forcing 128-byte alignment (required by 32-bit kernels)
const unsigned int alignment = 128;
const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
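The lifetime chosen above for PrePretransposedB follows directly from how B is consumed: with a constant B the transposed copy is only needed until prepare() finishes, unless it is the final form of B handed to the kernel, while a non-constant B is re-transposed on every run. A standalone restatement of that decision (hypothetical Lifetime enum, not the ACL MemoryLifetime type):

enum class Lifetime
{
    Prepare,    // freed once prepare() has consumed the buffer
    Persistent, // kept for the whole lifetime of the operator
    Temporary   // recreated on every run()
};

// Mirrors the branches used when registering the PrePretransposedB buffer.
inline Lifetime select_pre_pretranspose_lifetime(bool b_is_constant, bool b_pretranspose_required)
{
    if (!b_is_constant)
    {
        return Lifetime::Temporary; // transposed again on every run()
    }
    // Constant B: the buffer is only touched in prepare() ...
    return b_pretranspose_required ? Lifetime::Prepare     // ... and is superseded by the packed B
                                   : Lifetime::Persistent; // ... but is the final form of B, so it must persist
}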
@@ -506,6 +548,22 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
_gemm_kernel_asm->set_quantized_bias(
reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
}
+ const ITensor *b_to_use = b;
+ // Pre-pretranspose B if required
+ const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
+ CpuAuxTensorHandler pre_pretransposed_b(
+ offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors,
+ /*pack_inject: no need to inject into tensors*/
+ false,
+ /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/
+ !run_pre_pretranspose_b);
+ if (run_pre_pretranspose_b)
+ {
+ ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr);
+ ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}};
+ _pre_pretranspose_b->run(pre_pretranspose_pack);
+ b_to_use = pre_pretransposed_b.get();
+ }
// Pretranspose B if required
if (_gemm_kernel_asm->B_pretranspose_required())
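Note that the handler above is constructed unconditionally, with bypass_alloc set whenever the pre-pretranspose path is disabled, so no memory is acquired for a handle that will never be dereferenced. A simplified sketch of that pattern (hypothetical ScopedBuffer, not the CpuAuxTensorHandler API):

#include <cstddef>
#include <vector>

// Allocates only when the owning code path will actually use the buffer.
class ScopedBuffer
{
public:
    ScopedBuffer(std::size_t bytes, bool bypass_alloc)
    {
        if (!bypass_alloc)
        {
            _storage.resize(bytes);
        }
    }
    unsigned char *data() { return _storage.empty() ? nullptr : _storage.data(); }

private:
    std::vector<unsigned char> _storage{};
};

// Usage: construct up front, but only dereference when the path is enabled.
// ScopedBuffer pre_pretransposed_b(buffer_size, /*bypass_alloc=*/!run_pre_pretranspose_b);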
@@ -513,10 +571,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
- const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- const auto in1_ptr =
- reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
+ const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+ const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
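The stride arithmetic itself is unchanged; only the tensor it reads from switches to b_to_use. The leading dimension and the multi (batch) stride are the Y and Z byte strides divided by the element size, as in this standalone restatement (hypothetical strides struct, not the ITensorInfo API):

#include <cstddef>

// Hypothetical byte strides of a 3D tensor: y = between rows, z = between matrices (multis).
struct StridesInBytes
{
    std::size_t y;
    std::size_t z;
};

inline int leading_dim_b(const StridesInBytes &strides, std::size_t element_size)
{
    return static_cast<int>(strides.y / element_size); // ldb in elements
}

inline int multi_stride_b(const StridesInBytes &strides, std::size_t element_size)
{
    return static_cast<int>(strides.z / element_size); // batch stride in elements
}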
@@ -525,6 +583,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
NEScheduler::get().num_threads());
b->mark_as_unused();
+ // Note: b_to_use does not need to be marked as unused; if it points to pre_pretransposed_b, its memory is managed automatically by the aux tensor handler
}
if (_gemm_info.method == AsmConvMethod::Indirect)
@@ -576,16 +635,33 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const TypeInput *in1_ptr = nullptr;
auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
+ const ITensor *b_to_use = b;
+
+ // Pre-pretranspose B if required
+ const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
+ CpuAuxTensorHandler pre_pretransposed_b(
+ offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors,
+ false /*pack_inject: no need to inject into tensors*/,
+ !run_pre_pretranspose_b /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/);
+ if (b_to_use && !_is_b_constant && run_pre_pretranspose_b)
+ {
+ ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr);
+ ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}};
+ _pre_pretranspose_b->run(pre_pretranspose_pack);
+ b_to_use = pre_pretransposed_b.get();
+ }
+
// Check if B is pre-transposed and de-reference if not
if (!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
- in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+ ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+ multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
+ in1_ptr =
+ reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
}
// If necessary, run pretranspose every time if either weights or biases are non-constant
- if ((b && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
+ if ((b_to_use && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
{
if (c && c->info()->data_type() == DataType::S32)
{
@@ -596,10 +672,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Pretranspose B if required
if (_B_pretranspose_required)
{
- const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- const auto b_ptr =
- reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
+ // Fixed format kernels need no pretranspose.
+ ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
+ assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
+ const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+ const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
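Taken together, the run() changes re-run the extra transpose only for a non-constant B (a constant B was already handled in prepare()) and repeat the packing when either B or an S32 bias is non-constant. A hedged restatement of those guards (hypothetical helpers, simplified parameters):

// Whether the pre-pretranspose must run inside run(): only for a non-constant, transposed B.
inline bool rerun_pre_pretranspose(bool have_b, bool b_is_constant, bool transpose_b_requested, bool is_var_weights_kernel)
{
    return have_b && !b_is_constant && transpose_b_requested && !is_var_weights_kernel;
}

// Whether the packing step (pretranspose_B_array) must be repeated inside run().
inline bool rerun_pretranspose(bool have_b, bool b_is_constant, bool have_s32_bias, bool c_is_constant, bool b_pretranspose_required)
{
    const bool inputs_changed = (have_b && !b_is_constant) || (have_s32_bias && !c_is_constant);
    return inputs_changed && b_pretranspose_required;
}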
@@ -762,6 +841,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
info.fixed_format, info.fast_mode, &cfg);
+ // TODO: Incorporate info.transpose_b COMPMID-6595
switch (a->data_type())
{
case DataType::F32: