diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 38 |
1 files changed, 35 insertions, 3 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 00263eca04..666f83de71 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -225,14 +225,13 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId, const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType); ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr); - const Convolution2dDescriptor& descriptor = cLayer->GetParameters(); + const Convolution2dDescriptor& descriptor = cLayer->GetParameters(); // Construct optional biases object based on the value of m_BiasEnabled Optional<TensorInfo> biases; if (descriptor.m_BiasEnabled) { - biases = - OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType)); + biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType)); } result = layerSupportObject.IsConvolution2dSupported( @@ -244,6 +243,33 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId, reason); break; } + case LayerType::Convolution3d: + { + auto cLayer = PolymorphicDowncast<const Convolution3dLayer*>(&layer); + + const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), + dataType); + const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType); + ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr); + + const Convolution3dDescriptor& descriptor = cLayer->GetParameters(); + + // Construct optional biases object based on the value of m_BiasEnabled + Optional<TensorInfo> biases; + if (descriptor.m_BiasEnabled) + { + biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType)); + } + + result = layerSupportObject.IsConvolution3dSupported( + input, + output, + descriptor, + OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType), + biases, + reason); + break; + } case LayerType::Debug: { const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); @@ -1570,6 +1596,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convoluti return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution3d(const Convolution3dQueueDescriptor& /*descriptor*/, + const WorkloadInfo& /*info*/) const +{ + return std::unique_ptr<IWorkload>(); +} + std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const { |