diff options
Diffstat (limited to 'src/armnnDeserializeParser')
-rw-r--r-- | src/armnnDeserializeParser/DeserializeParser.cpp | 36 | ||||
-rw-r--r-- | src/armnnDeserializeParser/DeserializeParser.hpp | 7 | ||||
-rw-r--r-- | src/armnnDeserializeParser/DeserializerSupport.md | 1 |
3 files changed, 37 insertions, 7 deletions
diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp index 6a6a0fafe2..eb7bccaa1d 100644 --- a/src/armnnDeserializeParser/DeserializeParser.cpp +++ b/src/armnnDeserializeParser/DeserializeParser.cpp @@ -28,9 +28,11 @@ using armnn::ParseException; using namespace armnn; using namespace armnn::armnnSerializer; -namespace armnnDeserializeParser { +namespace armnnDeserializeParser +{ -namespace { +namespace +{ const uint32_t VIRTUAL_LAYER_ID = std::numeric_limits<uint32_t>::max(); @@ -132,8 +134,9 @@ DeserializeParser::DeserializeParser() m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer) { // register supported layers - m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd; - m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication; + m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd; + m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication; + m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializeParser::ParseSoftmax; } DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex) @@ -150,6 +153,8 @@ DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPt return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base(); case Layer::Layer_OutputLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base(); + case Layer::Layer_SoftmaxLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base(); case Layer::Layer_NONE: default: throw ParseException(boost::str( @@ -606,4 +611,27 @@ void DeserializeParser::ParseMultiplication(unsigned int layerIndex) RegisterOutputSlots(layerIndex, layer); } +void DeserializeParser::ParseSoftmax(unsigned int layerIndex) +{ + CHECK_LAYERS(m_Graph, 0, layerIndex); + + DeserializeParser::TensorRawPtrVector inputs = GetInputs(m_Graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 1); + + DeserializeParser::TensorRawPtrVector outputs = GetOutputs(m_Graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + armnn::SoftmaxDescriptor descriptor; + descriptor.m_Beta = m_Graph->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->descriptor()->beta(); + + const std::string layerName = boost::str(boost::format("Softmax:%1%") % layerIndex); + IConnectableLayer* layer = m_Network->AddSoftmaxLayer(descriptor, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + RegisterInputSlots(layerIndex, layer); + RegisterOutputSlots(layerIndex, layer); } + +} // namespace armnnDeserializeParser diff --git a/src/armnnDeserializeParser/DeserializeParser.hpp b/src/armnnDeserializeParser/DeserializeParser.hpp index ce343dc528..ddd02abede 100644 --- a/src/armnnDeserializeParser/DeserializeParser.hpp +++ b/src/armnnDeserializeParser/DeserializeParser.hpp @@ -62,9 +62,10 @@ private: // signature for the parser functions using LayerParsingFunction = void(DeserializeParser::*)(unsigned int layerIndex); - void ParseUnsupportedLayer(unsigned int serializeGraphIndex); - void ParseAdd(unsigned int serializeGraphIndex); - void ParseMultiplication(unsigned int serializeGraphIndex); + void ParseUnsupportedLayer(unsigned int layerIndex); + void ParseAdd(unsigned int layerIndex); + void ParseMultiplication(unsigned int layerIndex); + void ParseSoftmax(unsigned int layerIndex); void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot); void RegisterInputSlotOfConnection(uint32_t connectionIndex, armnn::IInputSlot* slot); diff --git a/src/armnnDeserializeParser/DeserializerSupport.md b/src/armnnDeserializeParser/DeserializerSupport.md index 8e1433419e..d4925cc0ad 100644 --- a/src/armnnDeserializeParser/DeserializerSupport.md +++ b/src/armnnDeserializeParser/DeserializerSupport.md @@ -8,5 +8,6 @@ The Arm NN SDK Deserialize parser currently supports the following layers: * Addition * Multiplication +* Softmax More machine learning layers will be supported in future releases. |