aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-04-05 13:37:19 +0100
committerNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-04-05 13:37:29 +0100
commit1f88630874fe346cd0cca8d8e38e0fb96cc1a3f4 (patch)
tree41acf0281797c5d4e9e515032ac989428efcb5b8 /src/backends/backendsCommon
parent647aab364aa13490427533c427496ad725b47f7a (diff)
downloadarmnn-1f88630874fe346cd0cca8d8e38e0fb96cc1a3f4.tar.gz
IVGCVSW-2915 Add Merge Layer and no-op factory method
Change-Id: I54549671e0d3b207904cf9796a843eb2b0a631f7 Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.cpp8
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.hpp5
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp22
-rw-r--r--src/backends/backendsCommon/WorkloadData.hpp5
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp18
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.hpp3
-rw-r--r--src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp2
7 files changed, 63 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 04f822cea9..fc2d502fbd 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -253,6 +253,14 @@ bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo& input,
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
+bool LayerSupportBase::IsMergeSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
bool LayerSupportBase::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const OriginsDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index 7d64095667..7c38b67379 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -160,6 +160,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsMergeSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const OriginsDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 91b1c5790b..348c864863 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1170,6 +1170,28 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
}
}
+void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateTwoInputs(workloadInfo, "MergeQueueDescriptor");
+ ValidateSingleOutput(workloadInfo, "MergeQueueDescriptor");
+
+ ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_InputTensorInfos[1],
+ "MergeQueueDescriptor",
+ "input0",
+ "input1");
+
+ ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_OutputTensorInfos[0],
+ "MergeQueueDescriptor",
+ "input0",
+ "output");
+
+ const DataType dataType = workloadInfo.m_InputTensorInfos[0].GetDataType();
+ ValidateTensorDataType(workloadInfo.m_InputTensorInfos[1], dataType, "MergeQueueDescriptor", "input1");
+ ValidateTensorDataType(workloadInfo.m_OutputTensorInfos[0], dataType, "MergeQueueDescriptor", "output");
+}
+
void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
// This is internally generated so it should not need validation.
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 5640701d82..1bf735288d 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -421,4 +421,9 @@ struct DequantizeQueueDescriptor : QueueDescriptor
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct MergeQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
} //namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 6534a00343..4ea3ea9f9b 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -519,6 +519,18 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::Merge:
+ {
+ const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+
+ result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
+ OverrideDataType(input1, dataType),
+ OverrideDataType(output, dataType),
+ reason);
+ break;
+ }
case LayerType::Merger:
{
auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
@@ -915,6 +927,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDes
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index ed7303cf33..889bc9d595 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -121,6 +121,9 @@ public:
virtual std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
+ virtual std::unique_ptr<IWorkload> CreateMerge(const MergeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
virtual std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index 26fb03f55d..0588607a82 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -362,6 +362,8 @@ DECLARE_LAYER_POLICY_1_PARAM(Maximum)
DECLARE_LAYER_POLICY_2_PARAM(Mean)
+DECLARE_LAYER_POLICY_1_PARAM(Merge)
+
DECLARE_LAYER_POLICY_2_PARAM(Merger)
DECLARE_LAYER_POLICY_1_PARAM(Minimum)