Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp  |  62
1 file changed, 38 insertions(+), 24 deletions(-)
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 611bc76463..58ee68fd49 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2023 Arm Limited.
+ * Copyright (c) 2018-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -60,7 +60,8 @@ void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutpu
const TypeInput *src,
int src_ld,
int src_multi_stride,
- unsigned int num_threads)
+ unsigned int num_threads,
+ bool transpose)
{
ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
ARM_COMPUTE_ERROR_ON(num_threads == 0);
@@ -77,7 +78,8 @@ void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutpu
if (start < end)
{
- gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, start, end);
+ gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, transpose, start,
+ end);
}
};
}
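
[Note: the helper patched above splits the B pretranspose into one contiguous [start, end) range per thread and forwards the new transpose flag to each part. A minimal standalone sketch of that chunking pattern, with pretranspose_part as a hypothetical stand-in for gemm_asm->pretranspose_B_array_part() and raw std::thread in place of NEScheduler:]

#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical stand-in for pretranspose_B_array_part(): the real call
// repacks a [start, end) slice of B into the assembly kernel's layout,
// optionally transposing it on the way in.
static void pretranspose_part(std::size_t start, std::size_t end, bool transpose)
{
    (void)start;
    (void)end;
    (void)transpose;
}

// Split total_work into one contiguous chunk per thread. The real helper
// asserts num_threads != 0, matching the ARM_COMPUTE_ERROR_ON above.
static void run_parallel(std::size_t total_work, unsigned int num_threads, bool transpose)
{
    const std::size_t chunk = (total_work + num_threads - 1) / num_threads;
    std::vector<std::thread> workers;
    for (unsigned int t = 0; t < num_threads; ++t)
    {
        const std::size_t start = t * chunk;
        const std::size_t end   = std::min(start + chunk, total_work);
        if (start < end) // same guard as in the patched helper
        {
            workers.emplace_back(pretranspose_part, start, end, transpose);
        }
    }
    for (auto &w : workers)
    {
        w.join();
    }
}

int main()
{
    run_parallel(1000, 4, /* transpose = */ true);
    return 0;
}
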
@@ -279,6 +281,8 @@ private:
bool _B_pretranspose_required{false};
bool _is_b_constant{true};
bool _is_c_constant{true};
+ bool _run_pre_pretranspose_b{false};
+ bool _B_pre_pretranspose_required{false};
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -443,8 +447,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
const AsmGemmInfo &gemm_info,
const OutputStage &os)
{
- ARM_COMPUTE_UNUSED(c);
-
_is_b_constant = b->are_values_constant();
_is_c_constant = c ? c->are_values_constant() : true;
@@ -479,16 +481,23 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
_optimised_kernel = std::move(acl_gemm_wrapper);
_gemm_info = gemm_info;
+
// Check if we need to pre-pretranspose B. Fixed format kernels need no pre-pretranspose.
- const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
- if (run_pre_pretranspose_b)
+ _B_pre_pretranspose_required = _gemm_info.transpose_b && !isVarWeightsKernel();
+ _B_pretranspose_required = _gemm_kernel_asm->B_pretranspose_required();
+
+ const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
+ const bool kernel_can_fuse_transpose = _B_pretranspose_required && kernel_supports_transpose;
+ _run_pre_pretranspose_b = _B_pre_pretranspose_required && !kernel_can_fuse_transpose;
+
+ if (_run_pre_pretranspose_b)
{
_pre_pretranspose_b = std::make_unique<CpuTranspose>();
_pre_pretranspose_b->configure(b, &_pre_pretransposed_b_info);
MemoryLifetime lifetime;
if (_is_b_constant)
{
- if (_gemm_kernel_asm->B_pretranspose_required())
+ if (_B_pretranspose_required)
{
// PrePretransposedB tensor is only used in prepare(), but is then succeeded by Pretranspose
// So PrePretransposedB can be freed inside prepare()
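
[Note: the flags introduced above implement the core decision of this patch: the standalone CpuTranspose pass runs only when B must be transposed, the kernel is not a variable-weights (fixed-format) one, and the assembly pretranspose cannot fuse the transpose itself. A hedged sketch of that derivation, with KernelTraits as hypothetical booleans mirroring the queries in the diff:]

#include <cstdio>

struct KernelTraits
{
    bool transpose_b;                       // gemm_info.transpose_b
    bool is_var_weights_kernel;             // isVarWeightsKernel()
    bool b_pretranspose_required;           // B_pretranspose_required()
    bool b_pretranspose_supports_transpose; // B_pretranspose_supports_transpose()
};

static bool run_pre_pretranspose_b(const KernelTraits &k)
{
    const bool pre_pretranspose_required = k.transpose_b && !k.is_var_weights_kernel;
    const bool kernel_can_fuse_transpose =
        k.b_pretranspose_required && k.b_pretranspose_supports_transpose;
    // A separate CpuTranspose pass is only needed when the kernel cannot
    // fold the transpose into its own B repacking.
    return pre_pretranspose_required && !kernel_can_fuse_transpose;
}

int main()
{
    const KernelTraits fused{true, false, true, true};    // kernel fuses the transpose
    const KernelTraits unfused{true, false, true, false}; // standalone pass required
    std::printf("fused: %d, unfused: %d\n",
                run_pre_pretranspose_b(fused), run_pre_pretranspose_b(unfused)); // 0, 1
    return 0;
}
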
@@ -513,7 +522,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
}
// Check for pre-transposed support
- if (_gemm_kernel_asm->B_pretranspose_required())
+ if (_B_pretranspose_required)
{
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
@@ -524,7 +533,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
_pretranspose_info = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
_aux_mem[Pretranspose] =
MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
- _B_pretranspose_required = true;
}
// Handle indirect GEMM convolution
@@ -550,15 +558,16 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
}
const ITensor *b_to_use = b;
+
// Pre-pretranspose B if required
- const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
CpuAuxTensorHandler pre_pretransposed_b(
offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors,
/*pack_inject: no need to inject into tensors*/
false,
/*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/
- !run_pre_pretranspose_b);
- if (run_pre_pretranspose_b)
+ !_run_pre_pretranspose_b);
+
+ if (_run_pre_pretranspose_b)
{
ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr);
ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}};
@@ -567,24 +576,29 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
}
// Pretranspose B if required
- if (_gemm_kernel_asm->B_pretranspose_required())
+ if (_B_pretranspose_required)
{
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
+
const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
b_to_use->info()->offset_first_element_in_bytes());
const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
+
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(),
- in1_ptr, ldb, multi_stride_b,
- NEScheduler::get().num_threads());
+
+ const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
+ _gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b,
+ NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
b->mark_as_unused();
- // Note that we don't need to mark b_to_use as unused, as if it's been assigned to pre_pretransposed_b, its memory will be auto-managed by the handler
+ // Note that we don't need to mark b_to_use as unused, as if it's been assigned to pre_pretransposed_b,
+ // its memory will be auto-managed by the handler
}
if (_gemm_info.method == AsmConvMethod::Indirect)
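
[Note: the ldb and multi_stride_b computations above convert ITensorInfo byte strides into element strides, which is what the assembly pretranspose expects. A small self-contained sketch of that conversion; TensorStrides is a hypothetical stand-in for the strides_in_bytes()/element_size() queries:]

#include <cstddef>

struct TensorStrides
{
    std::size_t stride_y_bytes; // strides_in_bytes().y()
    std::size_t stride_z_bytes; // strides_in_bytes().z()
    std::size_t element_size;   // element_size()
};

struct Leading
{
    int ldb;            // leading dimension of B, in elements
    int multi_stride_b; // stride between B matrices in a batch, in elements
};

static Leading leading_dims(const TensorStrides &s)
{
    return {static_cast<int>(s.stride_y_bytes / s.element_size),
            static_cast<int>(s.stride_z_bytes / s.element_size)};
}

int main()
{
    // e.g. 4-byte elements, 256-byte rows, 65536-byte matrices
    const TensorStrides s{256, 65536, 4};
    const Leading l = leading_dims(s);
    return (l.ldb == 64 && l.multi_stride_b == 16384) ? 0 : 1;
}
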
@@ -640,12 +654,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const ITensor *b_to_use = b;
// Pre-pretranspose B if required
- const bool run_pre_pretranspose_b = _gemm_info.transpose_b && !isVarWeightsKernel();
CpuAuxTensorHandler pre_pretransposed_b(
offset_int_vec(PrePretransposedB), _pre_pretransposed_b_info, tensors,
false /*pack_inject: no need to inject into tensors*/,
- !run_pre_pretranspose_b /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/);
- if (b_to_use && !_is_b_constant && run_pre_pretranspose_b)
+ !_run_pre_pretranspose_b /*bypass_alloc: no need to allocate if pre-pretranspose B is not required as this handle will not be used*/);
+ if (b_to_use && !_is_b_constant && _run_pre_pretranspose_b)
{
ARM_COMPUTE_ERROR_ON(_pre_pretranspose_b == nullptr);
ITensorPack pre_pretranspose_pack{{ACL_SRC, b_to_use}, {ACL_DST, pre_pretransposed_b.get()}};
@@ -691,9 +704,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
}
else
{
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(),
- b_ptr, ldb, multi_stride_b,
- NEScheduler::get().num_threads());
+ const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
+ _gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b,
+ NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
}
}
}
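
[Note: taken together, the patch should guarantee that when transpose_b is set for a non-fixed-format kernel, B is transposed exactly once: either by the standalone CpuTranspose pass or fused into the assembly pretranspose via the new flag. A hedged consistency sketch of that invariant, using hypothetical booleans in place of the kernel queries:]

#include <cassert>

int main()
{
    const bool b_pretranspose_required = true; // B_pretranspose_required()
    const bool pre_required            = true; // transpose_b && !isVarWeightsKernel()

    for (int fuse = 0; fuse <= 1; ++fuse)
    {
        const bool supports_transpose = (fuse != 0); // B_pretranspose_supports_transpose()
        // Standalone CpuTranspose pass, as computed in configure().
        const bool run_standalone =
            pre_required && !(b_pretranspose_required && supports_transpose);
        // Fused transpose flag, as passed to run_parallel_pretranspose_B_array().
        const bool run_fused = pre_required && supports_transpose;
        assert(run_standalone != run_fused); // exactly one path transposes B
    }
    return 0;
}
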