diff options
author | giuros01 <giuseppe.rossini@arm.com> | 2019-04-01 13:50:22 +0100 |
---|---|---|
committer | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2019-05-10 08:26:44 +0000 |
commit | 46a49a0a8206f0efa7afd514940e180a88ffd732 (patch) | |
tree | 0ec53af4ef65037e357b1d8f6a1d1f65075659f7 /src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp | |
parent | 879e8dd2fc8523e4059ba9ced9ea0edb57103778 (diff) | |
download | ComputeLibrary-46a49a0a8206f0efa7afd514940e180a88ffd732.tar.gz |
COMPMID-1635: Optimize CLDeconvolutionLayer - Part III
Change-Id: Id2661e093a669ef3eaf2a5116cd278a80c1d5a89
Signed-off-by: giuros01 <giuseppe.rossini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/935
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Gian Marco Iodice <gianmarco.iodice@arm.com>
Reviewed-by: Isabella Gottardi <isabella.gottardi@arm.com>
Comments-Addressed: Isabella Gottardi <isabella.gottardi@arm.com>
Tested-by: Isabella Gottardi <isabella.gottardi@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp b/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp index 721054dcf3..6e14e26cbd 100644 --- a/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp @@ -28,7 +28,6 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/CL/CLScheduler.h" -#include "utils/TypePrinter.h" #include <memory> #include <tuple> @@ -161,8 +160,16 @@ void CLDirectDeconvolutionLayer::configure(ICLTensor *input, ICLTensor *weights, _flip_axis.allocator()->allocate(); _flip_axis.map(true); auto axis_data = reinterpret_cast<uint32_t *>(_flip_axis.buffer()); - axis_data[0] = 0; - axis_data[1] = 1; + if(weights->info()->data_layout() == DataLayout::NHWC) + { + axis_data[0] = 1; + axis_data[1] = 2; + } + else + { + axis_data[0] = 0; + axis_data[1] = 1; + } _flip_axis.unmap(); } |