aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp13
1 files changed, 10 insertions, 3 deletions
diff --git a/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp b/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp
index 721054dcf3..6e14e26cbd 100644
--- a/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDirectDeconvolutionLayer.cpp
@@ -28,7 +28,6 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
-#include "utils/TypePrinter.h"
#include <memory>
#include <tuple>
@@ -161,8 +160,16 @@ void CLDirectDeconvolutionLayer::configure(ICLTensor *input, ICLTensor *weights,
_flip_axis.allocator()->allocate();
_flip_axis.map(true);
auto axis_data = reinterpret_cast<uint32_t *>(_flip_axis.buffer());
- axis_data[0] = 0;
- axis_data[1] = 1;
+ if(weights->info()->data_layout() == DataLayout::NHWC)
+ {
+ axis_data[0] = 1;
+ axis_data[1] = 2;
+ }
+ else
+ {
+ axis_data[0] = 0;
+ axis_data[1] = 1;
+ }
_flip_axis.unmap();
}