diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-02-13 15:41:52 +0000 |
---|---|---|
committer | Aron Virginas-Tar <aron.virginas-tar@arm.com> | 2019-02-18 14:49:04 +0000 |
commit | fc413c0c977e6c9680a2aa6546e977be0a2efdb9 (patch) | |
tree | c8eaeef557d1a8a0e5a18a7104f5b6d308c9efc2 /src/armnnDeserializeParser/DeserializeParser.cpp | |
parent | 2ee88dfe7096f8f571ae7be9cbf0f49ededd89af (diff) | |
download | armnn-fc413c0c977e6c9680a2aa6546e977be0a2efdb9.tar.gz |
IVGCVSW-2644 Add Serializer & Deserializer for Softmax
Change-Id: Ifea2108e173d2b602162fe53b880a68e1c715510
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Diffstat (limited to 'src/armnnDeserializeParser/DeserializeParser.cpp')
-rw-r--r-- | src/armnnDeserializeParser/DeserializeParser.cpp | 36 |
1 files changed, 32 insertions, 4 deletions
diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp index 6a6a0fafe2..eb7bccaa1d 100644 --- a/src/armnnDeserializeParser/DeserializeParser.cpp +++ b/src/armnnDeserializeParser/DeserializeParser.cpp @@ -28,9 +28,11 @@ using armnn::ParseException; using namespace armnn; using namespace armnn::armnnSerializer; -namespace armnnDeserializeParser { +namespace armnnDeserializeParser +{ -namespace { +namespace +{ const uint32_t VIRTUAL_LAYER_ID = std::numeric_limits<uint32_t>::max(); @@ -132,8 +134,9 @@ DeserializeParser::DeserializeParser() m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer) { // register supported layers - m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd; - m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication; + m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd; + m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication; + m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializeParser::ParseSoftmax; } DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex) @@ -150,6 +153,8 @@ DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPt return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base(); case Layer::Layer_OutputLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base(); + case Layer::Layer_SoftmaxLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base(); case Layer::Layer_NONE: default: throw ParseException(boost::str( @@ -606,4 +611,27 @@ void DeserializeParser::ParseMultiplication(unsigned int layerIndex) RegisterOutputSlots(layerIndex, layer); } +void DeserializeParser::ParseSoftmax(unsigned int layerIndex) +{ + CHECK_LAYERS(m_Graph, 0, layerIndex); + + DeserializeParser::TensorRawPtrVector inputs = GetInputs(m_Graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), 1); + + DeserializeParser::TensorRawPtrVector outputs = GetOutputs(m_Graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + armnn::SoftmaxDescriptor descriptor; + descriptor.m_Beta = m_Graph->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->descriptor()->beta(); + + const std::string layerName = boost::str(boost::format("Softmax:%1%") % layerIndex); + IConnectableLayer* layer = m_Network->AddSoftmaxLayer(descriptor, layerName.c_str()); + + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + RegisterInputSlots(layerIndex, layer); + RegisterOutputSlots(layerIndex, layer); } + +} // namespace armnnDeserializeParser |