diff options
Diffstat (limited to 'src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp | 25 |
1 files changed, 17 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp index 0fc3610014..d00f204b81 100644 --- a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp +++ b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -33,11 +33,11 @@ using namespace arm_compute; INEGEMMWrapperKernel::INEGEMMWrapperKernel() - : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _window3d(), _window_shape() + : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _gemm_info(), _window3d(), _window_shape() { } -INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c) +INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info) { Params p; @@ -45,21 +45,30 @@ INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITen ARM_COMPUTE_ERROR_ON_NULLPTR(b); ARM_COMPUTE_ERROR_ON_NULLPTR(c); + // Initalize params p.M = c->info()->tensor_shape().y(); p.N = c->info()->tensor_shape().x(); p.K = a->info()->tensor_shape().x(); p.multis = b->info()->tensor_shape().z(); p.batches = c->info()->tensor_shape().total_size_upper(2) / p.multis; //COMPMID-1423: Agree on and document the layout of gemm inputs/outputs + // Update M in case of GEMM3D for output + if(gemm_info.depth_output_gemm3d() != 0) + { + p.M = c->info()->tensor_shape().y() * c->info()->tensor_shape().z(); + p.batches = c->info()->tensor_shape().total_size_upper(3) / p.multis; + } + return p; } -void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta) +void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info) { - _params = extract_parameters(a, b, c); - _a = a; - _b = b; - _c = c; + _gemm_info = gemm_info; + _params = extract_parameters(a, b, c, gemm_info); + _a = a; + _b = b; + _c = c; _window3d = configure_internal(alpha, beta); _window_shape = _window3d.shape(); |