diff options
Diffstat (limited to 'src/armnnDeserializer')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 26 | ||||
-rw-r--r-- | src/armnnDeserializer/Deserializer.hpp | 1 | ||||
-rw-r--r-- | src/armnnDeserializer/DeserializerSupport.md | 1 |
3 files changed, 28 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 09cdd7cad3..076072e888 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -222,6 +222,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_SplitterLayer] = &Deserializer::ParseSplitter; m_ParserFunctions[Layer_StridedSliceLayer] = &Deserializer::ParseStridedSlice; m_ParserFunctions[Layer_SubtractionLayer] = &Deserializer::ParseSubtraction; + m_ParserFunctions[Layer_SwitchLayer] = &Deserializer::ParseSwitch; } Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex) @@ -306,6 +307,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_StridedSliceLayer()->base(); case Layer::Layer_SubtractionLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_SubtractionLayer()->base(); + case Layer::Layer_SwitchLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_SwitchLayer()->base(); case Layer::Layer_NONE: default: throw ParseException(boost::str( @@ -2108,4 +2111,27 @@ void Deserializer::ParseMerge(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +void Deserializer::ParseSwitch(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + auto inputs = GetInputs(graph, layerIndex); + CHECK_LOCATION(); + CHECK_VALID_SIZE(inputs.size(), 2); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 2); + + auto layerName = GetLayerName(graph, layerIndex); + IConnectableLayer* layer = m_Network->AddSwitchLayer(layerName.c_str()); + + armnn::TensorInfo output0TensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(output0TensorInfo); + + armnn::TensorInfo output1TensorInfo = ToTensorInfo(outputs[1]); + layer->GetOutputSlot(1).SetTensorInfo(output1TensorInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + } // namespace armnnDeserializer diff --git a/src/armnnDeserializer/Deserializer.hpp b/src/armnnDeserializer/Deserializer.hpp index df983d9086..dfa5b06057 100644 --- a/src/armnnDeserializer/Deserializer.hpp +++ b/src/armnnDeserializer/Deserializer.hpp @@ -114,6 +114,7 @@ private: void ParseSplitter(GraphPtr graph, unsigned int layerIndex); void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex); void ParseSubtraction(GraphPtr graph, unsigned int layerIndex); + void ParseSwitch(GraphPtr graph, unsigned int layerIndex); void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, armnn::IOutputSlot* slot); void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot); diff --git a/src/armnnDeserializer/DeserializerSupport.md b/src/armnnDeserializer/DeserializerSupport.md index 4e5610c569..770f7a88b0 100644 --- a/src/armnnDeserializer/DeserializerSupport.md +++ b/src/armnnDeserializer/DeserializerSupport.md @@ -41,5 +41,6 @@ The Arm NN SDK Deserialize parser currently supports the following layers: * Splitter * StridedSlice * Subtraction +* Switch More machine learning layers will be supported in future releases. |