diff options
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 25 |
1 files changed, 22 insertions, 3 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index df9ba1f255..d42404d25b 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -6,8 +6,10 @@ #include "RefLayerSupport.hpp" #include "RefBackendId.hpp" +#include <DataLayoutIndexed.hpp> #include <InternalTypes.hpp> #include <LayerSupportCommon.hpp> + #include <armnn/Types.hpp> #include <armnn/Descriptors.hpp> @@ -1725,8 +1727,6 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, const Optional<TensorInfo>& biases, Optional<std::string&> reasonIfUnsupported) const { - ignore_unused(descriptor); - bool supported = true; std::array<DataType,3> supportedTypes = @@ -1753,7 +1753,8 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, if (biases.has_value()) { - std::array<DataType,3> biasesSupportedTypes = { + std::array<DataType,3> biasesSupportedTypes = + { DataType::Float32, DataType::Signed32 }; @@ -1761,6 +1762,24 @@ bool RefLayerSupport::IsTransposeConvolution2dSupported(const TensorInfo& input, "Reference TransposeConvolution2d: biases is not a supported type."); } + // NOTE: Temporary restriction; should be removed as soon as support for channel + // multiplier different from 1 (input channels != output channels) has been added + struct ChannelsAreEqual : public Rule + { + ChannelsAreEqual(const TensorInfo& input, + const TensorInfo& output, + const TransposeConvolution2dDescriptor& descriptor) + { + armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout); + const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex(); + + m_Res = (input.GetShape()[channelsIndex] == output.GetShape()[channelsIndex]); + } + }; + + supported &= CheckSupportRule(ChannelsAreEqual(input, output, descriptor), reasonIfUnsupported, + "Reference TransposeConvolution2d: inputChannels != outputChannels"); + return supported; } |