author    Georgios Pinitas <georgios.pinitas@arm.com>    2019-06-21 18:43:12 +0100
committer Georgios Pinitas <georgios.pinitas@arm.com>    2019-07-05 15:30:24 +0000
commit    37d080f2f11cfd734104b76512e1fb191486216e (patch)
tree      d5df067c826aacc0676e7e9557a54b61a9a3b7eb /src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
parent    11de30da8a9f79943255ddba7bb70a66b076673b (diff)
download  ComputeLibrary-37d080f2f11cfd734104b76512e1fb191486216e.tar.gz
COMPMID-2378: Sanitize GEMM configuration for NEON
Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1418
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r--  src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp  68
1 file changed, 39 insertions(+), 29 deletions(-)
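In summary, the patch replaces the loose pretranspose_hint boolean that was threaded through the dispatch path with a full GEMMInfo descriptor, so the 3D-reshape meta-data reaches run() as well. Below is a minimal caller-side sketch of the post-patch interface; the wrapper function, alpha/beta values, and flag choices are illustrative assumptions, not part of this commit.

// Hypothetical usage sketch of the new configure() signature (assumed call
// site, not taken from this commit). The pretranspose hint now travels inside
// GEMMInfo together with the 3D-reinterpretation flags consumed by run().
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"

using namespace arm_compute;

void dispatch_gemm(const ITensor *a, const ITensor *b, ITensor *d, const GEMMInfo &gemm_info)
{
    NEGEMMAssemblyDispatch gemm;
    // Old signature: gemm.configure(a, b, d, alpha, beta, pretranspose_hint);
    // New signature: the whole GEMM configuration is passed as one object.
    gemm.configure(a, b, d, 1.f, 0.f, gemm_info);
    if(gemm.is_configured()) // configure() returns silently on unsupported data-type combinations
    {
        gemm.run();
    }
}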
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 55e067f52d..2de7d2b279 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -36,21 +36,22 @@ namespace arm_compute
namespace
{
std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info,
- const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint,
+ const ITensor *a, const ITensor *b, ITensor *d,
+ float alpha, float beta, const GEMMInfo &gemm_info,
std::shared_ptr<IMemoryManager> memory_manager)
{
- //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
+ // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
switch(gemm_kernel_info.method)
{
case arm_gemm::GemmMethod::GEMM_INTERLEAVED:
{
- if(!pretranspose_hint)
+ if(!gemm_info.pretranpose_B())
{
return nullptr;
}
auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
- function->configure(a, b, d, alpha, beta, pretranspose_hint);
+ function->configure(a, b, d, alpha, beta, gemm_info);
return std::move(function);
}
#if defined(__aarch64__)
@@ -59,7 +60,7 @@ std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescr
if(gemm_kernel_info.name.find("sgemm_native_16x4") != std::string::npos)
{
auto kernel = support::cpp14::make_unique<NEGEMMNativeWrapperKernel<float, float>>();
- kernel->configure(a, b, d, alpha, beta);
+ kernel->configure(a, b, d, alpha, beta, gemm_info);
auto function = support::cpp14::make_unique<NESimpleAssemblyFunction>();
function->configure(std::move(kernel));
return std::move(function);
@@ -83,9 +84,11 @@ public:
* @param[in] b Input tensor containing the Matrix B.
* @param[out] d Output tensor to store the result of matrix multiplication.
* @param[in] args Matrix multiplication information.
+ * @param[in] gemm_info GEMM meta-data
* @param[in] memory_group Memory group to be used by the function.
*/
- void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group);
+ void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args,
+ const GEMMInfo &gemm_info, MemoryGroup &memory_group);
// Inherited methods overridden:
void run() override;
@@ -123,10 +126,13 @@ private:
Tensor _pretranspose{};
/** Prepared flag */
bool _is_prepared{ false };
+ /** GEMM meta-data */
+ GEMMInfo _gemm_info{};
};
template <typename TypeInput, typename TypeOutput>
-void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args, MemoryGroup &memory_group)
+void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs<TypeOutput> args,
+ const GEMMInfo &gemm_info, MemoryGroup &memory_group)
{
arm_gemm::GemmConfig gemm_cfg;
const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
@@ -168,6 +174,7 @@ void Fallback<TypeInput, TypeOutput>::configure(const ITensor *a, const ITensor
_a = a;
_b = b;
_d = d;
+ _gemm_info = gemm_info;
// Check for pre-transposed support
if(_gemm_kernel_asm->B_pretranspose_required())
{
@@ -222,17 +229,17 @@ void Fallback<TypeInput, TypeOutput>::run()
int ldb = 0;
const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
- // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
- // the relevant multiple of the row stride.
- const bool is_nhwc = _a->info()->data_layout() == DataLayout::NHWC;
- const int stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();
+ const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d() != 0 ? 3 : 2;
+ const size_t a_multi_idx = a_batch_idx + 1;
+ const size_t d_batch_idx = _gemm_info.depth_output_gemm3d() != 0 ? 3 : 2;
+ const size_t d_multi_idx = d_batch_idx + 1;
- const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
- const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
+ const int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
+ const int batch_stride_d = _d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
- const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+ const int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
int multi_stride_b = 0;
- const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
+ const int multi_stride_d = _d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
const TypeInput *in1_ptr = nullptr;
@@ -270,24 +277,27 @@ void Fallback<TypeInput, TypeOutput>::run()
}
template <typename TypeInput, typename TypeOutput>
-void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
- ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr<IMemoryManager> memory_manager)
+void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function,
+ std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm,
+ MemoryGroup &memory_group, const ITensor *a, const ITensor *b,
+ ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
+ std::shared_ptr<IMemoryManager> memory_manager)
{
- INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d);
+ INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint);
+ arm_gemm::GemmArgs<TypeOutput> args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B());
//Try to create an ACL function:
- acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, pretranspose_hint, std::move(memory_manager));
+ acl_function = create_function_all_types(arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args), a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
//If we still don't have an ACL function:
if(acl_function == nullptr)
{
//Fallback onto arm_gemm function if ACL doesn't support this method.
auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
- fallback->configure(a, b, d, args, memory_group);
+ fallback->configure(a, b, d, args, gemm_info, memory_group);
arm_gemm = std::move(fallback);
}
}
@@ -299,11 +309,11 @@ NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> m
{
}
-Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint)
+Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_UNUSED(beta);
- ARM_COMPUTE_UNUSED(pretranspose_hint);
+ ARM_COMPUTE_UNUSED(gemm_info);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
#ifndef __aarch64__
@@ -319,14 +329,14 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
return Status{};
}
-void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint)
+void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a);
ARM_COMPUTE_ERROR_ON_NULLPTR(b);
ARM_COMPUTE_ERROR_ON_NULLPTR(d);
//If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
- if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, pretranspose_hint))
+ if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, gemm_info))
{
return;
}
@@ -334,20 +344,20 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITens
switch(a->info()->data_type())
{
case DataType::F32:
- create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
break;
case DataType::S8:
- create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager);
+ create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default:
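The stride hunk in run() is the behavioural core of the change: instead of inferring the batch stride of A from an NHWC layout check, the batch and multi dimension indices are now derived from the GEMMInfo 3D flags. A self-contained sketch of that index selection follows (a rewrite for illustration, not library code), under the assumption visible in the patch that reinterpreting a tensor as 3D shifts its batch dimension from shape index 2 to index 3.

// Minimal sketch of the new index selection in run(). When the input is
// reinterpreted as 3D (or the output is a 3D GEMM), the batch dimension moves
// up by one, and the "multi" (multi-matrix) dimension sits one index past it.
#include <cstddef>
#include <iostream>

struct GemmDims
{
    std::size_t batch_idx;
    std::size_t multi_idx;
};

// Mirrors the checks on reinterpret_input_as_3d() / depth_output_gemm3d().
GemmDims select_dims(int reinterpret_as_3d)
{
    const std::size_t batch_idx = (reinterpret_as_3d != 0) ? 3 : 2;
    return { batch_idx, batch_idx + 1 };
}

int main()
{
    const GemmDims a = select_dims(1); // 3D-reinterpreted input
    const GemmDims d = select_dims(0); // plain 2D output
    std::cout << a.batch_idx << ' ' << a.multi_idx << '\n'  // prints: 3 4
              << d.batch_idx << ' ' << d.multi_idx << '\n'; // prints: 2 3
}

The byte strides at the selected indices are then divided by the element size, exactly as the patched batch_stride_a/batch_stride_d and multi_stride_a/multi_stride_d lines do.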