aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-04-05 13:37:19 +0100
committerNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-04-05 13:37:29 +0100
commit1f88630874fe346cd0cca8d8e38e0fb96cc1a3f4 (patch)
tree41acf0281797c5d4e9e515032ac989428efcb5b8
parent647aab364aa13490427533c427496ad725b47f7a (diff)
downloadarmnn-1f88630874fe346cd0cca8d8e38e0fb96cc1a3f4.tar.gz
IVGCVSW-2915 Add Merge Layer and no-op factory method
Change-Id: I54549671e0d3b207904cf9796a843eb2b0a631f7 Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
-rw-r--r--Android.mk1
-rw-r--r--CMakeLists.txt2
-rw-r--r--include/armnn/ILayerSupport.hpp5
-rw-r--r--include/armnn/ILayerVisitor.hpp8
-rw-r--r--include/armnn/INetwork.hpp5
-rw-r--r--include/armnn/LayerSupport.hpp8
-rw-r--r--include/armnn/LayerVisitorBase.hpp3
-rw-r--r--src/armnn/InternalTypes.cpp1
-rw-r--r--src/armnn/InternalTypes.hpp1
-rw-r--r--src/armnn/LayerSupport.cpp10
-rw-r--r--src/armnn/LayersFwd.hpp2
-rw-r--r--src/armnn/Network.cpp5
-rw-r--r--src/armnn/Network.hpp2
-rw-r--r--src/armnn/layers/MergeLayer.cpp65
-rw-r--r--src/armnn/layers/MergeLayer.hpp47
-rw-r--r--src/armnn/test/NetworkTests.cpp52
-rw-r--r--src/armnnDeserializer/Deserializer.cpp23
-rw-r--r--src/armnnDeserializer/Deserializer.hpp1
-rw-r--r--src/armnnDeserializer/DeserializerSupport.md1
-rw-r--r--src/armnnSerializer/ArmnnSchema.fbs10
-rw-r--r--src/armnnSerializer/Serializer.cpp8
-rw-r--r--src/armnnSerializer/Serializer.hpp3
-rw-r--r--src/armnnSerializer/SerializerSupport.md1
-rw-r--r--src/armnnSerializer/test/SerializerTests.cpp40
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.cpp8
-rw-r--r--src/backends/backendsCommon/LayerSupportBase.hpp5
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp22
-rw-r--r--src/backends/backendsCommon/WorkloadData.hpp5
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp18
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.hpp3
-rw-r--r--src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp2
31 files changed, 364 insertions, 3 deletions
diff --git a/Android.mk b/Android.mk
index 85bd214379..6d5a0faa66 100644
--- a/Android.mk
+++ b/Android.mk
@@ -108,6 +108,7 @@ LOCAL_SRC_FILES := \
src/armnn/layers/MaximumLayer.cpp \
src/armnn/layers/MeanLayer.cpp \
src/armnn/layers/MemCopyLayer.cpp \
+ src/armnn/layers/MergeLayer.cpp \
src/armnn/layers/MergerLayer.cpp \
src/armnn/layers/MinimumLayer.cpp \
src/armnn/layers/MultiplicationLayer.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ec237aab04..d1fe635407 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -239,6 +239,8 @@ list(APPEND armnn_sources
src/armnn/layers/MeanLayer.cpp
src/armnn/layers/MemCopyLayer.hpp
src/armnn/layers/MemCopyLayer.cpp
+ src/armnn/layers/MergeLayer.hpp
+ src/armnn/layers/MergeLayer.cpp
src/armnn/layers/MergerLayer.hpp
src/armnn/layers/MergerLayer.cpp
src/armnn/layers/MinimumLayer.cpp
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index fe440719b0..1b75810aca 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -171,6 +171,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsMergeSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
virtual bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const OriginsDescriptor& descriptor,
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index e23cf5e6dd..3a4c39b7c6 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -199,6 +199,12 @@ public:
const MeanDescriptor& meanDescriptor,
const char* name = nullptr) = 0;
+ /// Function that a merge layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - pointer to the layer which is calling back to this visit function.
+ /// @param name - Optional name for the layer.
+ virtual void VisitMergeLayer(const IConnectableLayer* layer,
+ const char* name = nullptr) = 0;
+
/// Function that a merger layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param mergerDescriptor - WindowsDescriptor to configure the merging process. Number of Views must be equal to
@@ -337,4 +343,4 @@ public:
};
-} // namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 5a9d4f246e..8243b39c36 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -235,6 +235,11 @@ public:
virtual IConnectableLayer* AddSplitterLayer(const ViewsDescriptor& splitterDescriptor
, const char* name = nullptr) = 0;
+ /// Adds a merge layer to the network.
+ /// @param name - Optional name for the layer.
+ /// @return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddMergeLayer(const char* name = nullptr) = 0;
+
/// Adds a merger layer to the network.
/// @param mergerDescriptor - WindowsDescriptor to configure the merging process. Number of Views must be equal to
/// the number of inputs, and their order must match - e.g. first view corresponds to
diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp
index 7c6bc1372a..e23fdd0a75 100644
--- a/include/armnn/LayerSupport.hpp
+++ b/include/armnn/LayerSupport.hpp
@@ -204,6 +204,14 @@ bool IsMemCopySupported(const BackendId& backend,
size_t reasonIfUnsupportedMaxLength = 1024);
/// Deprecated in favor of IBackend and ILayerSupport interfaces
+bool IsMergeSupported(const BackendId& backend,
+ const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ char* reasonIfUnsupported = nullptr,
+ size_t reasonIfUnsupportedMaxLength = 1024);
+
+/// Deprecated in favor of IBackend and ILayerSupport interfaces
bool IsMergerSupported(const BackendId& backend,
const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index a5459e1a32..f4e0f438be 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -87,6 +87,9 @@ public:
const ViewsDescriptor&,
const char*) override { DefaultPolicy::Apply(); }
+ void VisitMergeLayer(const IConnectableLayer*,
+ const char*) override { DefaultPolicy::Apply(); }
+
void VisitMergerLayer(const IConnectableLayer*,
const OriginsDescriptor&,
const char*) override { DefaultPolicy::Apply(); }
diff --git a/src/armnn/InternalTypes.cpp b/src/armnn/InternalTypes.cpp
index fe1542b162..93a4f94378 100644
--- a/src/armnn/InternalTypes.cpp
+++ b/src/armnn/InternalTypes.cpp
@@ -39,6 +39,7 @@ char const* GetLayerTypeAsCString(LayerType type)
case LayerType::Maximum: return "Maximum";
case LayerType::Mean: return "Mean";
case LayerType::MemCopy: return "MemCopy";
+ case LayerType::Merge: return "Merge";
case LayerType::Merger: return "Merger";
case LayerType::Minimum: return "Minimum";
case LayerType::Multiplication: return "Multiplication";
diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp
index 1972e9c1b5..7c7c601d95 100644
--- a/src/armnn/InternalTypes.hpp
+++ b/src/armnn/InternalTypes.hpp
@@ -39,6 +39,7 @@ enum class LayerType
Maximum,
Mean,
MemCopy,
+ Merge,
Merger,
Minimum,
Multiplication,
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp
index 030973306f..bc6eec891b 100644
--- a/src/armnn/LayerSupport.cpp
+++ b/src/armnn/LayerSupport.cpp
@@ -355,6 +355,16 @@ bool IsMemCopySupported(const BackendId &backend,
FORWARD_LAYER_SUPPORT_FUNC(backend, IsMemCopySupported, input, output);
}
+bool IsMergeSupported(const BackendId& backend,
+ const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ char* reasonIfUnsupported,
+ size_t reasonIfUnsupportedMaxLength)
+{
+ FORWARD_LAYER_SUPPORT_FUNC(backend, IsMergeSupported, input0, input1, output);
+}
+
bool IsMergerSupported(const BackendId& backend,
std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index 9d87aeeee3..0bd68e04af 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -31,6 +31,7 @@
#include "layers/MaximumLayer.hpp"
#include "layers/MeanLayer.hpp"
#include "layers/MemCopyLayer.hpp"
+#include "layers/MergeLayer.hpp"
#include "layers/MergerLayer.hpp"
#include "layers/MinimumLayer.hpp"
#include "layers/MultiplicationLayer.hpp"
@@ -102,6 +103,7 @@ DECLARE_LAYER(Lstm)
DECLARE_LAYER(Maximum)
DECLARE_LAYER(Mean)
DECLARE_LAYER(MemCopy)
+DECLARE_LAYER(Merge)
DECLARE_LAYER(Merger)
DECLARE_LAYER(Minimum)
DECLARE_LAYER(Multiplication)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 6dbd4611df..73db2e88d7 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -966,6 +966,11 @@ IConnectableLayer* Network::AddGatherLayer(const char* name)
return m_Graph->AddLayer<GatherLayer>(name);
}
+IConnectableLayer* Network::AddMergeLayer(const char* name)
+{
+ return m_Graph->AddLayer<MergeLayer>(name);
+}
+
void Network::Accept(ILayerVisitor& visitor) const
{
for (auto layer : GetGraph())
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index 782531acde..bb7b9eb6f4 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -174,6 +174,8 @@ public:
IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
+ IConnectableLayer* AddMergeLayer(const char* name = nullptr) override;
+
void Accept(ILayerVisitor& visitor) const override;
private:
diff --git a/src/armnn/layers/MergeLayer.cpp b/src/armnn/layers/MergeLayer.cpp
new file mode 100644
index 0000000000..1d4dc49379
--- /dev/null
+++ b/src/armnn/layers/MergeLayer.cpp
@@ -0,0 +1,65 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "MergeLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+namespace armnn
+{
+
+MergeLayer::MergeLayer(const char* name)
+ : Layer(2, 1, LayerType::Merge, name)
+{}
+
+std::unique_ptr<IWorkload> MergeLayer::CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const
+{
+ return nullptr;
+}
+
+MergeLayer* MergeLayer::Clone(Graph& graph) const
+{
+ return CloneBase<MergeLayer>(graph, GetName());
+}
+
+void MergeLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(2, CHECK_LOCATION());
+
+ std::vector<TensorShape> inferredShapes = InferOutputShapes({
+ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
+ });
+
+ BOOST_ASSERT(inferredShapes.size() == 1);
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "MergeLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+ GetOutputSlot(0).GetTensorInfo().GetShape(),
+ inferredShapes[0]);
+}
+
+std::vector<TensorShape> MergeLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+ BOOST_ASSERT(inputShapes.size() == 2);
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "MergeLayer: TensorShapes set on inputs do not match",
+ inputShapes[0],
+ inputShapes[1]
+ );
+
+ return {inputShapes[0]};
+}
+
+void MergeLayer::Accept(ILayerVisitor& visitor) const
+{
+ visitor.VisitMergeLayer(this, GetName());
+}
+
+} // namespace armnn
diff --git a/src/armnn/layers/MergeLayer.hpp b/src/armnn/layers/MergeLayer.hpp
new file mode 100644
index 0000000000..66664ca952
--- /dev/null
+++ b/src/armnn/layers/MergeLayer.hpp
@@ -0,0 +1,47 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "Layer.hpp"
+
+namespace armnn
+{
+
+/// This layer merges its two input tensors (which must have matching shapes) into a single output tensor.
+class MergeLayer : public Layer
+{
+public:
+ /// Makes a workload for the Merge type.
+ /// @param [in] graph The graph where this layer can be found.
+ /// @param [in] factory The workload factory which will create the workload.
+ /// @return A pointer to the created workload, or nullptr if not created.
+ virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const override;
+
+ /// Creates a dynamically-allocated copy of this layer.
+ /// @param [in] graph The graph into which this layer is being cloned.
+ MergeLayer* Clone(Graph& graph) const override;
+
+ /// Check if the input tensor shape(s)
+ /// will lead to a valid configuration of @ref MergeLayer.
+ void ValidateTensorShapesFromInputs() override;
+
+ /// Infers the output shapes from given input shapes.
+ /// @param [in] inputShapes The input shapes layer has.
+ /// @return A vector to the inferred output shape.
+ std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
+
+ void Accept(ILayerVisitor& visitor) const override;
+
+protected:
+ /// Constructor to create a MergeLayer.
+ /// @param [in] name Optional name for the layer.
+ MergeLayer(const char* name);
+
+ /// Default destructor
+ ~MergeLayer() = default;
+};
+
+} // namespace armnn
diff --git a/src/armnn/test/NetworkTests.cpp b/src/armnn/test/NetworkTests.cpp
index 4de09a2804..dd8eb7773f 100644
--- a/src/armnn/test/NetworkTests.cpp
+++ b/src/armnn/test/NetworkTests.cpp
@@ -417,4 +417,56 @@ BOOST_AUTO_TEST_CASE(Network_AddQuantize)
}
+BOOST_AUTO_TEST_CASE(Network_AddMerge)
+{
+ struct Test : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
+ {
+ void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override
+ {
+ m_Visited = true;
+
+ BOOST_TEST(layer);
+
+ std::string expectedName = std::string("merge");
+ BOOST_TEST(std::string(layer->GetName()) == expectedName);
+ BOOST_TEST(std::string(name) == expectedName);
+
+ BOOST_TEST(layer->GetNumInputSlots() == 2);
+ BOOST_TEST(layer->GetNumOutputSlots() == 1);
+
+ const armnn::TensorInfo& infoIn0 = layer->GetInputSlot(0).GetConnection()->GetTensorInfo();
+ BOOST_TEST((infoIn0.GetDataType() == armnn::DataType::Float32));
+
+ const armnn::TensorInfo& infoIn1 = layer->GetInputSlot(1).GetConnection()->GetTensorInfo();
+ BOOST_TEST((infoIn1.GetDataType() == armnn::DataType::Float32));
+
+ const armnn::TensorInfo& infoOut = layer->GetOutputSlot(0).GetTensorInfo();
+ BOOST_TEST((infoOut.GetDataType() == armnn::DataType::Float32));
+ }
+
+ bool m_Visited = false;
+ };
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+
+ armnn::IConnectableLayer* input0 = network->AddInputLayer(0);
+ armnn::IConnectableLayer* input1 = network->AddInputLayer(1);
+ armnn::IConnectableLayer* merge = network->AddMergeLayer("merge");
+ armnn::IConnectableLayer* output = network->AddOutputLayer(0);
+
+ input0->GetOutputSlot(0).Connect(merge->GetInputSlot(0));
+ input1->GetOutputSlot(0).Connect(merge->GetInputSlot(1));
+ merge->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+ const armnn::TensorInfo info({3,1}, armnn::DataType::Float32);
+ input0->GetOutputSlot(0).SetTensorInfo(info);
+ input1->GetOutputSlot(0).SetTensorInfo(info);
+ merge->GetOutputSlot(0).SetTensorInfo(info);
+
+ Test testMerge;
+ network->Accept(testMerge);
+
+ BOOST_TEST(testMerge.m_Visited == true);
+}
+
BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 943c6a7fed..09cdd7cad3 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -206,6 +206,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
m_ParserFunctions[Layer_MaximumLayer] = &Deserializer::ParseMaximum;
m_ParserFunctions[Layer_MeanLayer] = &Deserializer::ParseMean;
m_ParserFunctions[Layer_MinimumLayer] = &Deserializer::ParseMinimum;
+ m_ParserFunctions[Layer_MergeLayer] = &Deserializer::ParseMerge;
m_ParserFunctions[Layer_MergerLayer] = &Deserializer::ParseMerger;
m_ParserFunctions[Layer_MultiplicationLayer] = &Deserializer::ParseMultiplication;
m_ParserFunctions[Layer_NormalizationLayer] = &Deserializer::ParseNormalization;
@@ -271,6 +272,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
return graphPtr->layers()->Get(layerIndex)->layer_as_MinimumLayer()->base();
case Layer::Layer_MaximumLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_MaximumLayer()->base();
+ case Layer::Layer_MergeLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_MergeLayer()->base();
case Layer::Layer_MergerLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_MergerLayer()->base();
case Layer::Layer_MultiplicationLayer:
@@ -2085,4 +2088,24 @@ void Deserializer::ParseDequantize(GraphPtr graph, unsigned int layerIndex)
RegisterOutputSlots(graph, layerIndex, layer);
}
+void Deserializer::ParseMerge(GraphPtr graph, unsigned int layerIndex)
+{
+ CHECK_LAYERS(graph, 0, layerIndex);
+
+ Deserializer::TensorRawPtrVector inputs = GetInputs(graph, layerIndex);
+ CHECK_VALID_SIZE(inputs.size(), 2);
+
+ Deserializer::TensorRawPtrVector outputs = GetOutputs(graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+
+ const std::string layerName = GetLayerName(graph, layerIndex);
+ IConnectableLayer* layer = m_Network->AddMergeLayer(layerName.c_str());
+
+ armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
+ layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ RegisterInputSlots(graph, layerIndex, layer);
+ RegisterOutputSlots(graph, layerIndex, layer);
+}
+
} // namespace armnnDeserializer
diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp
index f18c163035..df983d9086 100644
--- a/src/armnnDeserializer/Deserializer.hpp
+++ b/src/armnnDeserializer/Deserializer.hpp
@@ -97,6 +97,7 @@ private:
void ParseMaximum(GraphPtr graph, unsigned int layerIndex);
void ParseMean(GraphPtr graph, unsigned int layerIndex);
void ParseMinimum(GraphPtr graph, unsigned int layerIndex);
+ void ParseMerge(GraphPtr graph, unsigned int layerIndex);
void ParseMerger(GraphPtr graph, unsigned int layerIndex);
void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
void ParseNormalization(GraphPtr graph, unsigned int layerIndex);
diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md
index 77856cf389..4e5610c569 100644
--- a/src/armnnDeserializer/DeserializerSupport.md
+++ b/src/armnnDeserializer/DeserializerSupport.md
@@ -25,6 +25,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers:
* Lstm
* Maximum
* Mean
+* Merge
* Merger
* Minimum
* Multiplication
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs
index 3aa644dbe5..8b275b6f17 100644
--- a/src/armnnSerializer/ArmnnSchema.fbs
+++ b/src/armnnSerializer/ArmnnSchema.fbs
@@ -118,7 +118,8 @@ enum LayerType : uint {
DetectionPostProcess = 33,
Lstm = 34,
Quantize = 35,
- Dequantize = 36
+ Dequantize = 36,
+ Merge = 37
}
// Base layer table to be used as part of other layers
@@ -524,6 +525,10 @@ table DequantizeLayer {
base:LayerBase;
}
+table MergeLayer {
+ base:LayerBase;
+}
+
union Layer {
ActivationLayer,
AdditionLayer,
@@ -561,7 +566,8 @@ union Layer {
DetectionPostProcessLayer,
LstmLayer,
QuantizeLayer,
- DequantizeLayer
+ DequantizeLayer,
+ MergeLayer
}
table AnyLayer {
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 7181f01e6b..fe30c3eee5 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -500,6 +500,14 @@ void SerializerVisitor::VisitMinimumLayer(const armnn::IConnectableLayer* layer,
CreateAnyLayer(fbMinimumLayer.o, serializer::Layer::Layer_MinimumLayer);
}
+void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name)
+{
+ auto fbMergeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merge);
+ auto fbMergeLayer = serializer::CreateMergeLayer(m_flatBufferBuilder, fbMergeBaseLayer);
+
+ CreateAnyLayer(fbMergeLayer.o, serializer::Layer::Layer_MergeLayer);
+}
+
void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
const armnn::OriginsDescriptor& mergerDescriptor,
const char* name)
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 5c3e48a695..775df83966 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -129,6 +129,9 @@ public:
void VisitMaximumLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
+ void VisitMergeLayer(const armnn::IConnectableLayer* layer,
+ const char* name = nullptr) override;
+
void VisitMergerLayer(const armnn::IConnectableLayer* layer,
const armnn::OriginsDescriptor& mergerDescriptor,
const char* name = nullptr) override;
diff --git a/src/armnnSerializer/SerializerSupport.md b/src/armnnSerializer/SerializerSupport.md
index a3c5852bd2..a8335e1e68 100644
--- a/src/armnnSerializer/SerializerSupport.md
+++ b/src/armnnSerializer/SerializerSupport.md
@@ -25,6 +25,7 @@ The Arm NN SDK Serializer currently supports the following layers:
* Lstm
* Maximum
* Mean
+* Merge
* Merger
* Minimum
* Multiplication
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index 0979076476..a1ef9eef59 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1185,6 +1185,46 @@ BOOST_AUTO_TEST_CASE(SerializeMean)
deserializedNetwork->Accept(verifier);
}
+BOOST_AUTO_TEST_CASE(SerializeMerge)
+{
+ class MergeLayerVerifier : public LayerVerifierBase
+ {
+ public:
+ MergeLayerVerifier(const std::string& layerName,
+ const std::vector<armnn::TensorInfo>& inputInfos,
+ const std::vector<armnn::TensorInfo>& outputInfos)
+ : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
+
+ void VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name) override
+ {
+ VerifyNameAndConnections(layer, name);
+ }
+ };
+
+ const std::string layerName("merge");
+ const armnn::TensorInfo info({ 1, 2, 2, 3 }, armnn::DataType::Float32);
+
+ armnn::INetworkPtr network = armnn::INetwork::Create();
+ armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0);
+ armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1);
+ armnn::IConnectableLayer* const mergeLayer = network->AddMergeLayer(layerName.c_str());
+ armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
+
+ inputLayer0->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(0));
+ inputLayer1->GetOutputSlot(0).Connect(mergeLayer->GetInputSlot(1));
+ mergeLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
+
+ inputLayer0->GetOutputSlot(0).SetTensorInfo(info);
+ inputLayer1->GetOutputSlot(0).SetTensorInfo(info);
+ mergeLayer->GetOutputSlot(0).SetTensorInfo(info);
+
+ armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
+ BOOST_CHECK(deserializedNetwork);
+
+ MergeLayerVerifier verifier(layerName, {info, info}, {info});
+ deserializedNetwork->Accept(verifier);
+}
+
BOOST_AUTO_TEST_CASE(SerializeMerger)
{
class MergerLayerVerifier : public LayerVerifierBase
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 04f822cea9..fc2d502fbd 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -253,6 +253,14 @@ bool LayerSupportBase::IsMemCopySupported(const armnn::TensorInfo& input,
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
+bool LayerSupportBase::IsMergeSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
bool LayerSupportBase::IsMergerSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const OriginsDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index 7d64095667..7c38b67379 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -160,6 +160,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsMergeSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const OriginsDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 91b1c5790b..348c864863 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1170,6 +1170,28 @@ void DequantizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
}
}
+void MergeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateTwoInputs(workloadInfo, "MergeQueueDescriptor");
+ ValidateSingleOutput(workloadInfo, "MergeQueueDescriptor");
+
+ ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_InputTensorInfos[1],
+ "MergeQueueDescriptor",
+ "input0",
+ "input1");
+
+ ValidateTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_OutputTensorInfos[0],
+ "MergeQueueDescriptor",
+ "input0",
+ "output");
+
+ const DataType dataType = workloadInfo.m_InputTensorInfos[0].GetDataType();
+ ValidateTensorDataType(workloadInfo.m_InputTensorInfos[1], dataType, "MergeQueueDescriptor", "input1");
+ ValidateTensorDataType(workloadInfo.m_OutputTensorInfos[0], dataType, "MergeQueueDescriptor", "output");
+}
+
void PreCompiledQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
// This is internally generated so it should not need validation.
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 5640701d82..1bf735288d 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -421,4 +421,9 @@ struct DequantizeQueueDescriptor : QueueDescriptor
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct MergeQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
} //namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 6534a00343..4ea3ea9f9b 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -519,6 +519,18 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::Merge:
+ {
+ const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+
+ result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
+ OverrideDataType(input1, dataType),
+ OverrideDataType(output, dataType),
+ reason);
+ break;
+ }
case LayerType::Merger:
{
auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
@@ -915,6 +927,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDes
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index ed7303cf33..889bc9d595 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -121,6 +121,9 @@ public:
virtual std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
+ virtual std::unique_ptr<IWorkload> CreateMerge(const MergeQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const;
+
virtual std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index 26fb03f55d..0588607a82 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -362,6 +362,8 @@ DECLARE_LAYER_POLICY_1_PARAM(Maximum)
DECLARE_LAYER_POLICY_2_PARAM(Mean)
+DECLARE_LAYER_POLICY_1_PARAM(Merge)
+
DECLARE_LAYER_POLICY_2_PARAM(Merger)
DECLARE_LAYER_POLICY_1_PARAM(Minimum)