diff options
Diffstat (limited to 'src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp b/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp index 721054dcf3..6e14e26cbd 100644 --- a/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp @@ -28,7 +28,6 @@ #include "arm_compute/core/Validate.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/runtime/CL/CLScheduler.h" -#include "utils/TypePrinter.h" #include <memory> #include <tuple> @@ -161,8 +160,16 @@ void CLDirectDeconvolutionLayer::configure(ICLTensor *input, ICLTensor *weights, _flip_axis.allocator()->allocate(); _flip_axis.map(true); auto axis_data = reinterpret_cast<uint32_t *>(_flip_axis.buffer()); - axis_data[0] = 0; - axis_data[1] = 1; + if(weights->info()->data_layout() == DataLayout::NHWC) + { + axis_data[0] = 1; + axis_data[1] = 2; + } + else + { + axis_data[0] = 0; + axis_data[1] = 1; + } _flip_axis.unmap(); } |