From 651aafec6235ddbc1177229dedcdd21ffad01616 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Mon, 5 Aug 2019 11:52:05 +0100 Subject: IVGCVSW-3611 Report TransposeConvolution2d as unsupported on CpuRef when channel multiplier != 1 Signed-off-by: Aron Virginas-Tar Change-Id: I784bbff3f7b6650881d3f70ba7cd1891171195b1 --- src/backends/reference/RefLayerSupport.cpp | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index df9ba1f255..d42404d25b 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -6,8 +6,10 @@ #include "RefLayerSupport.hpp" #include "RefBackendId.hpp" +#include #include #include + #include #include @@ -1725,8 +1727,6 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, const Optional& biases, Optional reasonIfUnsupported) const { - ignore_unused(descriptor); - bool supported = true; std::array supportedTypes = @@ -1753,7 +1753,8 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, if (biases.has_value()) { - std::array biasesSupportedTypes = { + std::array biasesSupportedTypes = + { DataType::Float32, DataType::Signed32 }; @@ -1761,6 +1762,24 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, "Reference TransposeConvolution2d: biases is not a supported type."); } + // NOTE: Temporary restriction; should be removed as soon as support for channel + // multiplier different from 1 (input channels != output channels) has been added + struct ChannelsAreEqual : public Rule + { + ChannelsAreEqual(const TensorInfo& input, + const TensorInfo& output, + const TransposeConvolution2dDescriptor& descriptor) + { + armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout); + const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex(); + + m_Res = (input.GetShape()[channelsIndex] == output.GetShape()[channelsIndex]); + } + }; + + supported &= CheckSupportRule(ChannelsAreEqual(input, output, descriptor), reasonIfUnsupported, + "Reference TransposeConvolution2d: inputChannels != outputChannels"); + return supported; } -- cgit v1.2.1