diff options
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.cpp | 8 | ||||
-rw-r--r-- | src/backends/backendsCommon/LayerSupportBase.hpp | 5 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 22 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.hpp | 5 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 18 | ||||
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.hpp | 3 | ||||
-rw-r--r-- | src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp | 2 |
7 files changed, 63 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index 04f822cea9..fc2d502fbd 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -253,6 +253,14 @@ bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo& input, return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } +bool LayerSupportBase::IsMergeSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported) const +{ + return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); +} + bool LayerSupportBase::IsMergerSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, const OriginsDescriptor& descriptor, diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index 7d64095667..7c38b67379 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -160,6 +160,11 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsMergeSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsMergerSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, const OriginsDescriptor& descriptor, diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 91b1c5790b..348c864863 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1170,6 +1170,28 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const } } +void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateTwoInputs(workloadInfo, "MergeQueueDescriptor"); + ValidateSingleOutput(workloadInfo, "MergeQueueDescriptor"); + + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_InputTensorInfos[1], + "MergeQueueDescriptor", + "input0", + "input1"); + + ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], + workloadInfo.m_OutputTensorInfos[0], + "MergeQueueDescriptor", + "input0", + "output"); + + const DataType dataType = workloadInfo.m_InputTensorInfos[0].GetDataType(); + ValidateTensorDataType(workloadInfo.m_InputTensorInfos[1], dataType, "MergeQueueDescriptor", "input1"); + ValidateTensorDataType(workloadInfo.m_OutputTensorInfos[0], dataType, "MergeQueueDescriptor", "output"); +} + void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { // This is internally generated so it should not need validation. diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp index 5640701d82..1bf735288d 100644 --- a/src/backends/backendsCommon/WorkloadData.hpp +++ b/src/backends/backendsCommon/WorkloadData.hpp @@ -421,4 +421,9 @@ struct DequantizeQueueDescriptor : QueueDescriptor void Validate(const WorkloadInfo& workloadInfo) const; }; +struct MergeQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + } //namespace armnn diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 6534a00343..4ea3ea9f9b 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -519,6 +519,18 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::Merge: + { + const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo(); + + result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType), + OverrideDataType(input1, dataType), + OverrideDataType(output, dataType), + reason); + break; + } case LayerType::Merger: { auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer); @@ -915,6 +927,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDes return std::unique_ptr<IWorkload>(); } +std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::unique_ptr<IWorkload>(); +} + std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp index ed7303cf33..889bc9d595 100644 --- a/src/backends/backendsCommon/WorkloadFactory.hpp +++ b/src/backends/backendsCommon/WorkloadFactory.hpp @@ -121,6 +121,9 @@ public: virtual std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor, const WorkloadInfo& info) const; + virtual std::unique_ptr<IWorkload> CreateMerge(const MergeQueueDescriptor& descriptor, + const WorkloadInfo& info) const; + virtual std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor, const WorkloadInfo& info) const; diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp index 26fb03f55d..0588607a82 100644 --- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp +++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp @@ -362,6 +362,8 @@ DECLARE_LAYER_POLICY_1_PARAM(Maximum) DECLARE_LAYER_POLICY_2_PARAM(Mean) +DECLARE_LAYER_POLICY_1_PARAM(Merge) + DECLARE_LAYER_POLICY_2_PARAM(Merger) DECLARE_LAYER_POLICY_1_PARAM(Minimum) |