diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-08-05 11:52:05 +0100 |
---|---|---|
committer | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-08-05 11:52:05 +0100 |
commit | 651aafec6235ddbc1177229dedcdd21ffad01616 (patch) | |
tree | e935ddb864f25fb48026957d841bff20bcffb272 /src/backends | |
parent | 198ee400733aa633ddeb867e42cc241a684e9787 (diff) | |
download | armnn-651aafec6235ddbc1177229dedcdd21ffad01616.tar.gz |
IVGCVSW-3611 Report TransposeConvolution2d as unsupported on CpuRef when channel multiplier != 1
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I784bbff3f7b6650881d3f70ba7cd1891171195b1
Diffstat (limited to 'src/backends')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 25 |
1 files changed, 22 insertions, 3 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index df9ba1f255..d42404d25b 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -6,8 +6,10 @@ #include "RefLayerSupport.hpp" #include "RefBackendId.hpp" +#include <DataLayoutIndexed.hpp> #include <InternalTypes.hpp> #include <LayerSupportCommon.hpp> + #include <armnn/Types.hpp> #include <armnn/Descriptors.hpp> @@ -1725,8 +1727,6 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, const Optional<TensorInfo>& biases, Optional<std::string&> reasonIfUnsupported) const { - ignore_unused(descriptor); - bool supported = true; std::array<DataType,3> supportedTypes = @@ -1753,7 +1753,8 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, if (biases.has_value()) { - std::array<DataType,3> biasesSupportedTypes = { + std::array<DataType,3> biasesSupportedTypes = + { DataType::Float32, DataType::Signed32 }; @@ -1761,6 +1762,24 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, "Reference TransposeConvolution2d: biases is not a supported type."); } + // NOTE: Temporary restriction; should be removed as soon as support for channel + // multiplier different from 1 (input channels != output channels) has been added + struct ChannelsAreEqual : public Rule + { + ChannelsAreEqual(const TensorInfo& input, + const TensorInfo& output, + const TransposeConvolution2dDescriptor& descriptor) + { + armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout); + const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex(); + + m_Res = (input.GetShape()[channelsIndex] == output.GetShape()[channelsIndex]); + } + }; + + supported &= CheckSupportRule(ChannelsAreEqual(input, output, descriptor), reasonIfUnsupported, + "Reference TransposeConvolution2d: inputChannels != outputChannels"); + return supported; } |