about | summary | refs | log | tree | commit | diff
path: root/arm_compute/runtime
diff options
context:
space:
mode:
author: Michalis Spyrou <michalis.spyrou@arm.com> 2018-04-25 18:10:13 +0100
committer: Anthony Barbier <anthony.barbier@arm.com> 2018-11-02 16:50:15 +0000
commit: 2b3129ebb9e4366e91de5031d1e1d3759cc42c8e (patch)
tree: 86b0d4f7870f8f548a68fae43cc32913d0d6dd9e /arm_compute/runtime
parent: 99d40951df87790fb884ce1c42d5e2a7a0009ee0 (diff)
download: ComputeLibrary-2b3129ebb9e4366e91de5031d1e1d3759cc42c8e.tar.gz
COMPMID-1041 NEON Winograd: update function to use GEMM function
Change-Id: I1ecdf10e02193de7f47a72b75cce0d58a1fa1a1c
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128411
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Diffstat (limited to 'arm_compute/runtime')
-rw-r--r-- arm_compute/runtime/NEON/functions/NEWinogradLayer.h | 7
1 files changed, 5 insertions, 2 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h
index 27b1e84201..8010810253 100644
--- a/arm_compute/runtime/NEON/functions/NEWinogradLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEWinogradLayer.h
@@ -27,6 +27,7 @@
#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/core/NEON/INEKernel.h"
+#include "arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CPP/functions/CPPPermute.h"
#include "arm_compute/runtime/MemoryGroup.h"
@@ -93,8 +94,9 @@ public:
NEWinogradLayer &operator=(const NEWinogradLayer &) = delete;
private:
- MemoryGroup _memory_group;
- std::unique_ptr<INEKernel> _batched_gemm_kernel;
+ MemoryGroup _memory_group;
+ std::unique_ptr<arm_gemm::GemmCommon<float, float>> _arm_gemm;
+ std::unique_ptr<INEKernel> _gemm_kernel;
std::unique_ptr<INEKernel> _transform_input_kernel;
std::unique_ptr<INEKernel> _transform_output_kernel;
std::unique_ptr<INEKernel> _transform_weights_kernel;
@@ -109,6 +111,7 @@ private:
Tensor _input_nhwc;
Tensor _output_nhwc;
Tensor _weights_hwio;
+ Tensor _workspace;
const ITensor *_input;
const ITensor *_weights;
ITensor *_output;