diff options
Diffstat (limited to 'src/armnnDeserializeParser/DeserializeParser.cpp')
-rw-r--r-- | src/armnnDeserializeParser/DeserializeParser.cpp | 36 |
1 files changed, 32 insertions, 4 deletions
diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp index 6a6a0fafe2..eb7bccaa1d 100644 --- a/src/armnnDeserializeParser/DeserializeParser.cpp +++ b/src/armnnDeserializeParser/DeserializeParser.cpp @@ -28,9 +28,11 @@ using armnn::ParseException; using namespace armnn; using namespace armnn::armnnSerializer; -namespace armnnDeserializeParser { +namespace armnnDeserializeParser +{ -namespace { +namespace +{ const uint32_t VIRTUAL_LAYER_ID = std::numeric_limits<uint32_t>::max(); @@ -132,8 +134,9 @@ DeserializeParser::DeserializeParser() m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer) { // register supported layers - m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd; - m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication; + m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd; + m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication; + m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializeParser::ParseSoftmax; } DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex) @@ -150,6 +153,8 @@ DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPt return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base(); case Layer::Layer_OutputLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base(); + case Layer::Layer_SoftmaxLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base(); case Layer::Layer_NONE: default: throw ParseException(boost::str( @@ -606,4 +611,27 @@ void DeserializeParser::ParseMultiplication(unsigned int layerIndex) RegisterOutputSlots(layerIndex, layer); } +void DeserializeParser::ParseSoftmax(unsigned int layerIndex) +{ + CHECK_LAYERS(m_Graph, 0, layerIndex); + + DeserializeParser::TensorRawPtrVector inputs = GetInputs(m_Graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 1); + + DeserializeParser::TensorRawPtrVector outputs = GetOutputs(m_Graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + armnn::SoftmaxDescriptor descriptor; + descriptor.m_Beta = m_Graph->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->descriptor()->beta(); + + const std::string layerName = boost::str(boost::format("Softmax:%1%") % layerIndex); + IConnectableLayer* layer = m_Network->AddSoftmaxLayer(descriptor, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + RegisterInputSlots(layerIndex, layer); + RegisterOutputSlots(layerIndex, layer); } + +} // namespace armnnDeserializeParser |