diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-06-20 14:28:19 +0100 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-06-21 09:54:30 +0000 |
commit | 639fb0437d1a5a8a6ea737fed5a16b554dfffead (patch) | |
tree | 5b89adc18c1a071d23747a28dcddcfca41e4d815 /src/backends/backendsCommon/WorkloadFactory.cpp | |
parent | 713e95c8c531c5cecd804a7cecc8af745917315c (diff) | |
download | armnn-639fb0437d1a5a8a6ea737fed5a16b554dfffead.tar.gz |
IVGCVSW-3319 Add frontend support for TransposeConvolution2d Layer
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Ic06f63f1eff255e697facf319e2ac4c83d782e7c
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index cca39198e1..2fba3b7059 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -796,6 +796,36 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::TransposeConvolution2d: + { + auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer); + + const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), + dataType); + const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType); + + const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters(); + + Optional<TensorInfo> biases; + if (descriptor.m_BiasEnabled) + { + BOOST_ASSERT(cLayer->m_Bias.get() != nullptr); + biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), + GetBiasTypeFromWeightsType(dataType)); + } + + BOOST_ASSERT(cLayer->m_Weight.get() != nullptr); + const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType); + + result = layerSupportObject->IsTransposeConvolution2dSupported(input, + output, + descriptor, + weights, + biases, + reason); + + break; + } default: { BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer."); @@ -1098,4 +1128,11 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescr return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d( + const TransposeConvolution2dQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr<IWorkload>(); } + +} // namepsace armnn
\ No newline at end of file |