diff options
Diffstat (limited to 'src/armnnDeserializer')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 23 | ||||
-rw-r--r-- | src/armnnDeserializer/Deserializer.hpp | 1 | ||||
-rw-r--r-- | src/armnnDeserializer/DeserializerSupport.md | 1 |
3 files changed, 25 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 943c6a7fed..09cdd7cad3 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -206,6 +206,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_MaximumLayer] = &Deserializer::ParseMaximum; m_ParserFunctions[Layer_MeanLayer] = &Deserializer::ParseMean; m_ParserFunctions[Layer_MinimumLayer] = &Deserializer::ParseMinimum; + m_ParserFunctions[Layer_MergeLayer] = &Deserializer::ParseMerge; m_ParserFunctions[Layer_MergerLayer] = &Deserializer::ParseMerger; m_ParserFunctions[Layer_MultiplicationLayer] = &Deserializer::ParseMultiplication; m_ParserFunctions[Layer_NormalizationLayer] = &Deserializer::ParseNormalization; @@ -271,6 +272,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_MinimumLayer()->base(); case Layer::Layer_MaximumLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_MaximumLayer()->base(); + case Layer::Layer_MergeLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_MergeLayer()->base(); case Layer::Layer_MergerLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_MergerLayer()->base(); case Layer::Layer_MultiplicationLayer: @@ -2085,4 +2088,24 @@ void Deserializer::ParseDequantize(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +void Deserializer::ParseMerge(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + + Deserializer::TensorRawPtrVector inputs = GetInputs(graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 2); + + Deserializer::TensorRawPtrVector outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + const std::string layerName = GetLayerName(graph, layerIndex); + IConnectableLayer* layer = m_Network->AddMergeLayer(layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + } // namespace armnnDeserializer diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index f18c163035..df983d9086 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -97,6 +97,7 @@ private: void ParseMaximum(GraphPtr graph, unsigned int layerIndex); void ParseMean(GraphPtr graph, unsigned int layerIndex); void ParseMinimum(GraphPtr graph, unsigned int layerIndex); + void ParseMerge(GraphPtr graph, unsigned int layerIndex); void ParseMerger(GraphPtr graph, unsigned int layerIndex); void ParseMultiplication(GraphPtr graph, unsigned int layerIndex); void ParseNormalization(GraphPtr graph, unsigned int layerIndex); diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md index 77856cf389..4e5610c569 100644 --- a/src/armnnDeserializer/DeserializerSupport.md +++ b/src/armnnDeserializer/DeserializerSupport.md @@ -25,6 +25,7 @@ The Arm NN SDK Deserialize parser currently supports the following layers: * Lstm * Maximum * Mean +* Merge * Merger * Minimum * Multiplication |