From 98180eff3c1feed19576d82ceac06d476235d973 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Wed, 26 Jun 2019 15:02:47 +0100 Subject: IVGCVSW-3324 Add end-to-end tests for TransposeConvolution2d on CpuRef * Added one end-to-end test for all supported data types and data layout * Implemented RefLayerSupport::IsTransposeConvolution2dSupported() * Fixed formula used in TransposeConvolution2dLayer::InferOutputShapes() Signed-off-by: Aron Virginas-Tar Change-Id: If1ba3c226ecfa17f7fceffae857f39297c6433f2 --- src/backends/reference/RefLayerSupport.cpp | 46 ++++++++++++++++++++++++ src/backends/reference/RefLayerSupport.hpp | 8 +++++ src/backends/reference/RefWorkloadFactory.cpp | 4 +++ src/backends/reference/test/RefEndToEndTests.cpp | 38 ++++++++++++++++++++ 4 files changed, 96 insertions(+) (limited to 'src/backends/reference') diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 06d9e1bff9..429993a55f 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1466,4 +1466,50 @@ bool RefLayerSupport::IsPreluSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, + const TensorInfo& output, + const TransposeConvolution2dDescriptor& descriptor, + const TensorInfo& weights, + const Optional& biases, + Optional reasonIfUnsupported) const +{ + ignore_unused(descriptor); + + bool supported = true; + + std::array supportedTypes = + { + DataType::Float32, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference TransposeConvolution2d: input is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference TransposeConvolution2d: output is not a supported type."); + + supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported, + "Reference TransposeConvolution2d: weights is not a supported type."); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference TransposeConvolution2d: input and output types mismatched."); + + supported &= CheckSupportRule(TypesAreEqual(input, weights), reasonIfUnsupported, + "Reference TransposeConvolution2d: input and weights types mismatched."); + + if (biases.has_value()) + { + std::array biasesSupportedTypes = { + DataType::Float32, + DataType::Signed32 + }; + supported &= CheckSupportRule(TypeAnyOf(biases.value(), biasesSupportedTypes), reasonIfUnsupported, + "Reference TransposeConvolution2d: biases is not a supported type."); + } + + return supported; +} + } // namespace armnn diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 5d241492c2..9c397fe66b 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -266,6 +266,14 @@ public: const TensorInfo& alpha, const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const override; + + bool IsTransposeConvolution2dSupported( + const TensorInfo& input, + const TensorInfo& output, + const TransposeConvolution2dDescriptor& descriptor, + const TensorInfo& weights, + const Optional& biases, + Optional reasonIfUnsupported = EmptyOptional()) const override; }; } // namespace armnn diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 5ede8b3f02..d906f93a38 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -468,6 +468,10 @@ std::unique_ptr RefWorkloadFactory::CreateTransposeConvolution2d( const TransposeConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const { + if (IsFloat16(info)) + { + return MakeWorkload(descriptor, info); + } return std::make_unique(descriptor, info); } diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 4d56952e27..a528a54cd2 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -930,4 +931,41 @@ BOOST_AUTO_TEST_CASE(RefSplitter4dDim3EndToEndUint8Test) Splitter4dDim3EndToEnd(defaultBackends); } +// TransposeConvolution2d +BOOST_AUTO_TEST_CASE(RefTransposeConvolution2dEndToEndFloatNchwTest) +{ + TransposeConvolution2dEndToEnd( + defaultBackends, armnn::DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(RefTransposeConvolution2dEndToEndUint8NchwTest) +{ + TransposeConvolution2dEndToEnd( + defaultBackends, armnn::DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(RefTransposeConvolution2dEndToEndInt16NchwTest) +{ + TransposeConvolution2dEndToEnd( + defaultBackends, armnn::DataLayout::NCHW); +} + +BOOST_AUTO_TEST_CASE(RefTransposeConvolution2dEndToEndFloatNhwcTest) +{ + TransposeConvolution2dEndToEnd( + defaultBackends, armnn::DataLayout::NHWC); +} + +BOOST_AUTO_TEST_CASE(RefTransposeConvolution2dEndToEndUint8NhwcTest) +{ + TransposeConvolution2dEndToEnd( + defaultBackends, armnn::DataLayout::NHWC); +} + +BOOST_AUTO_TEST_CASE(RefTransposeConvolution2dEndToEndInt16NhwcTest) +{ + TransposeConvolution2dEndToEnd( + defaultBackends, armnn::DataLayout::NHWC); +} + BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1