diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index cca39198e1..2fba3b7059 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -796,6 +796,36 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::TransposeConvolution2d: + { + auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer); + + const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), + dataType); + const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType); + + const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters(); + + Optional<TensorInfo> biases; + if (descriptor.m_BiasEnabled) + { + BOOST_ASSERT(cLayer->m_Bias.get() != nullptr); + biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), + GetBiasTypeFromWeightsType(dataType)); + } + + BOOST_ASSERT(cLayer->m_Weight.get() != nullptr); + const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType); + + result = layerSupportObject->IsTransposeConvolution2dSupported(input, + output, + descriptor, + weights, + biases, + reason); + + break; + } default: { BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer."); @@ -1098,4 +1128,11 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescr return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d( + const TransposeConvolution2dQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr<IWorkload>(); } + +} // namepsace armnn
\ No newline at end of file |