aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/backends/backendsCommon/test/TransposeConvolution2dTestImpl.hpp3
-rw-r--r--src/backends/cl/test/ClLayerTests.cpp13
-rw-r--r--src/backends/reference/workloads/TransposeConvolution2d.cpp2
3 files changed, 16 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/test/TransposeConvolution2dTestImpl.hpp b/src/backends/backendsCommon/test/TransposeConvolution2dTestImpl.hpp
index 9140c19383..64caa3fce1 100644
--- a/src/backends/backendsCommon/test/TransposeConvolution2dTestImpl.hpp
+++ b/src/backends/backendsCommon/test/TransposeConvolution2dTestImpl.hpp
@@ -493,7 +493,8 @@ LayerTestResult<T, 4> MultiChannelTransposeConvolution2dTest(
TensorShape inputShape = MakeTensorShape(1, 1, 2, 2, layout);
TensorShape outputShape = MakeTensorShape(1, 2, 5, 5, layout);
- TensorShape weightsShape = MakeTensorShape(1, 2, 3, 3, layout);
+ // OIHW for NCHW; OHWI for NHWC
+ TensorShape weightsShape = MakeTensorShape(2, 1, 3, 3, layout);
TensorShape biasesShape = { 2 };
TensorInfo inputInfo(inputShape, ArmnnType);
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index d3f39219f3..8a5435b83c 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -760,6 +760,19 @@ ARMNN_AUTO_TEST_CASE(UnbiasedStridedTransposeConvolution2dUint8Nhwc,
true,
DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(MultiChannelTransposeConvolution2dFloatNchw,
+ MultiChannelTransposeConvolution2dTest<DataType::Float32, DataType::Float32>,
+ DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(MultiChannelTransposeConvolution2dFloatNhwc,
+ MultiChannelTransposeConvolution2dTest<DataType::Float32, DataType::Float32>,
+ DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(MultiChannelTransposeConvolution2dUint8Nchw,
+ MultiChannelTransposeConvolution2dTest<DataType::QuantisedAsymm8, DataType::Signed32>,
+ DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(MultiChannelTransposeConvolution2dUint8Nhwc,
+ MultiChannelTransposeConvolution2dTest<DataType::QuantisedAsymm8, DataType::Signed32>,
+ DataLayout::NHWC)
+
// ============================================================================
// COMPARE tests
diff --git a/src/backends/reference/workloads/TransposeConvolution2d.cpp b/src/backends/reference/workloads/TransposeConvolution2d.cpp
index acbfe0cc90..52cc18c17a 100644
--- a/src/backends/reference/workloads/TransposeConvolution2d.cpp
+++ b/src/backends/reference/workloads/TransposeConvolution2d.cpp
@@ -83,7 +83,7 @@ void TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor& descript
inputDecoder[inputIndex];
const unsigned int weightsIndex =
- dataLayoutIndexed.GetIndex(weightsShape, batch, dOutput, yWeights, xWeights);
+ dataLayoutIndexed.GetIndex(weightsShape, dOutput, dInput, yWeights, xWeights);
weightsDecoder[weightsIndex];
const unsigned int outputIndex =