diff options
Diffstat (limited to 'src/armnnDeserializer')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 30 |
1 files changed, 24 insertions, 6 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 14cf232cdb..75c258b7ab 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -192,6 +192,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_AdditionLayer] = &Deserializer::ParseAdd; m_ParserFunctions[Layer_BatchToSpaceNdLayer] = &Deserializer::ParseBatchToSpaceNd; m_ParserFunctions[Layer_BatchNormalizationLayer] = &Deserializer::ParseBatchNormalization; + m_ParserFunctions[Layer_ConcatLayer] = &Deserializer::ParseConcat; m_ParserFunctions[Layer_ConstantLayer] = &Deserializer::ParseConstant; m_ParserFunctions[Layer_Convolution2dLayer] = &Deserializer::ParseConvolution2d; m_ParserFunctions[Layer_DepthwiseConvolution2dLayer] = &Deserializer::ParseDepthwiseConvolution2d; @@ -241,6 +242,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_BatchToSpaceNdLayer()->base(); case Layer::Layer_BatchNormalizationLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_BatchNormalizationLayer()->base(); + case Layer::Layer_ConcatLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_ConcatLayer()->base(); case Layer::Layer_ConstantLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_ConstantLayer()->base(); case Layer::Layer_Convolution2dLayer: @@ -1229,6 +1232,22 @@ void Deserializer::ParseMaximum(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +const armnnSerializer::OriginsDescriptor* GetOriginsDescriptor(const armnnSerializer::SerializedGraph* graph, + unsigned int layerIndex) +{ + auto layerType = graph->layers()->Get(layerIndex)->layer_type(); + + switch (layerType) + { + case Layer::Layer_ConcatLayer: + return graph->layers()->Get(layerIndex)->layer_as_ConcatLayer()->descriptor(); + case Layer::Layer_MergerLayer: + return graph->layers()->Get(layerIndex)->layer_as_MergerLayer()->descriptor(); + default: + throw armnn::Exception("unknown layer type, should be concat or merger"); + } +} + void Deserializer::ParseConcat(GraphPtr graph, unsigned int layerIndex) { CHECK_LAYERS(graph, 0, layerIndex); @@ -1237,18 +1256,17 @@ void Deserializer::ParseConcat(GraphPtr graph, unsigned int layerIndex) auto outputs = GetOutputs(graph, layerIndex); CHECK_VALID_SIZE(outputs.size(), 1); - auto mergerLayer = graph->layers()->Get(layerIndex)->layer_as_MergerLayer(); auto layerName = GetLayerName(graph, layerIndex); - auto mergerDescriptor = mergerLayer->descriptor(); - unsigned int numViews = mergerDescriptor->numViews(); - unsigned int numDimensions = mergerDescriptor->numDimensions(); + auto originsDescriptor = GetOriginsDescriptor(graph, layerIndex); + unsigned int numViews = originsDescriptor->numViews(); + unsigned int numDimensions = originsDescriptor->numDimensions(); // can now check the number of inputs == number of views auto inputs = GetInputs(graph, layerIndex); CHECK_VALID_SIZE(inputs.size(), numViews); armnn::OriginsDescriptor descriptor(numViews, numDimensions); - auto originsPtr = mergerDescriptor->viewOrigins(); + auto originsPtr = originsDescriptor->viewOrigins(); for (unsigned int v = 0; v < numViews; ++v) { auto originPtr = originsPtr->Get(v); @@ -1258,7 +1276,7 @@ void Deserializer::ParseConcat(GraphPtr graph, unsigned int layerIndex) descriptor.SetViewOriginCoord(v, d, value); } } - descriptor.SetConcatAxis(mergerDescriptor->concatAxis()); + descriptor.SetConcatAxis(originsDescriptor->concatAxis()); IConnectableLayer* layer = m_Network->AddConcatLayer(descriptor, layerName.c_str()); armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); |