aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-06-20 14:28:19 +0100
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-06-21 09:54:30 +0000
commit639fb0437d1a5a8a6ea737fed5a16b554dfffead (patch)
tree5b89adc18c1a071d23747a28dcddcfca41e4d815 /src/backends/backendsCommon/WorkloadFactory.cpp
parent713e95c8c531c5cecd804a7cecc8af745917315c (diff)
downloadarmnn-639fb0437d1a5a8a6ea737fed5a16b554dfffead.tar.gz
IVGCVSW-3319 Add frontend support for TransposeConvolution2d Layer
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: Ic06f63f1eff255e697facf319e2ac4c83d782e7c
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp37
1 files changed, 37 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index cca39198e1..2fba3b7059 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -796,6 +796,36 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::TransposeConvolution2d:
+ {
+ auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
+
+ const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
+ dataType);
+ const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
+
+ const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
+
+ Optional<TensorInfo> biases;
+ if (descriptor.m_BiasEnabled)
+ {
+ BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
+ biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
+ GetBiasTypeFromWeightsType(dataType));
+ }
+
+ BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
+ const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
+
+ result = layerSupportObject->IsTransposeConvolution2dSupported(input,
+ output,
+ descriptor,
+ weights,
+ biases,
+ reason);
+
+ break;
+ }
default:
{
BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
@@ -1098,4 +1128,11 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescr
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
+ const TransposeConvolution2dQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::unique_ptr<IWorkload>();
}
+
+} // namepsace armnn \ No newline at end of file