aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-06-21 18:43:12 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-07-05 15:30:24 +0000
commit37d080f2f11cfd734104b76512e1fb191486216e (patch)
treed5df067c826aacc0676e7e9557a54b61a9a3b7eb /src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
parent11de30da8a9f79943255ddba7bb70a66b076673b (diff)
downloadComputeLibrary-37d080f2f11cfd734104b76512e1fb191486216e.tar.gz
COMPMID-2378: Sanitize GEMM configuration for NEON
Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/c/1418 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp')
-rw-r--r--src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp25
1 files changed, 17 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
index 0fc3610014..d00f204b81 100644
--- a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
+++ b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -33,11 +33,11 @@
using namespace arm_compute;
INEGEMMWrapperKernel::INEGEMMWrapperKernel()
- : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _window3d(), _window_shape()
+ : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _gemm_info(), _window3d(), _window_shape()
{
}
-INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c)
+INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info)
{
Params p;
@@ -45,21 +45,30 @@ INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITen
ARM_COMPUTE_ERROR_ON_NULLPTR(b);
ARM_COMPUTE_ERROR_ON_NULLPTR(c);
+ // Initalize params
p.M = c->info()->tensor_shape().y();
p.N = c->info()->tensor_shape().x();
p.K = a->info()->tensor_shape().x();
p.multis = b->info()->tensor_shape().z();
p.batches = c->info()->tensor_shape().total_size_upper(2) / p.multis; //COMPMID-1423: Agree on and document the layout of gemm inputs/outputs
+ // Update M in case of GEMM3D for output
+ if(gemm_info.depth_output_gemm3d() != 0)
+ {
+ p.M = c->info()->tensor_shape().y() * c->info()->tensor_shape().z();
+ p.batches = c->info()->tensor_shape().total_size_upper(3) / p.multis;
+ }
+
return p;
}
-void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta)
+void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info)
{
- _params = extract_parameters(a, b, c);
- _a = a;
- _b = b;
- _c = c;
+ _gemm_info = gemm_info;
+ _params = extract_parameters(a, b, c, gemm_info);
+ _a = a;
+ _b = b;
+ _c = c;
_window3d = configure_internal(alpha, beta);
_window_shape = _window3d.shape();