diff options
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r-- | src/armnnDeserializer/Deserializer.cpp | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index ed110ad750..d62751d640 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -201,6 +201,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer) m_ParserFunctions[Layer_MaximumLayer] = &Deserializer::ParseMaximum; m_ParserFunctions[Layer_MeanLayer] = &Deserializer::ParseMean; m_ParserFunctions[Layer_MinimumLayer] = &Deserializer::ParseMinimum; + m_ParserFunctions[Layer_MergerLayer] = &Deserializer::ParseMerger; m_ParserFunctions[Layer_MultiplicationLayer] = &Deserializer::ParseMultiplication; m_ParserFunctions[Layer_NormalizationLayer] = &Deserializer::ParseNormalization; m_ParserFunctions[Layer_PadLayer] = &Deserializer::ParsePad; @@ -255,6 +256,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt return graphPtr->layers()->Get(layerIndex)->layer_as_MinimumLayer()->base(); case Layer::Layer_MaximumLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_MaximumLayer()->base(); + case Layer::Layer_MergerLayer: + return graphPtr->layers()->Get(layerIndex)->layer_as_MergerLayer()->base(); case Layer::Layer_MultiplicationLayer: return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base(); case Layer::Layer_NormalizationLayer: @@ -1111,6 +1114,45 @@ void Deserializer::ParseMaximum(GraphPtr graph, unsigned int layerIndex) RegisterOutputSlots(graph, layerIndex, layer); } +void Deserializer::ParseMerger(GraphPtr graph, unsigned int layerIndex) +{ + CHECK_LAYERS(graph, 0, layerIndex); + CHECK_LOCATION(); + + auto outputs = GetOutputs(graph, layerIndex); + CHECK_VALID_SIZE(outputs.size(), 1); + + auto mergerLayer = graph->layers()->Get(layerIndex)->layer_as_MergerLayer(); + auto layerName = GetLayerName(graph, layerIndex); + auto mergerDescriptor = mergerLayer->descriptor(); + unsigned int numViews = mergerDescriptor->numViews(); + unsigned int numDimensions = mergerDescriptor->numDimensions(); + + // can now check the number of inputs == number of views + auto inputs = GetInputs(graph, layerIndex); + CHECK_VALID_SIZE(inputs.size(), numViews); + + armnn::OriginsDescriptor descriptor(numViews, numDimensions); + auto originsPtr = mergerDescriptor->viewOrigins(); + for (unsigned int v = 0; v < numViews; ++v) + { + auto originPtr = originsPtr->Get(v); + for (unsigned int d = 0; d < numDimensions; ++d) + { + uint32_t value = originPtr->data()->Get(d); + descriptor.SetViewOriginCoord(v, d, value); + } + } + descriptor.SetConcatAxis(mergerDescriptor->concatAxis()); + + IConnectableLayer* layer = m_Network->AddMergerLayer(descriptor, layerName.c_str()); + armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + RegisterInputSlots(graph, layerIndex, layer); + RegisterOutputSlots(graph, layerIndex, layer); +} + void Deserializer::ParseMultiplication(GraphPtr graph, unsigned int layerIndex) { CHECK_LAYERS(graph, 0, layerIndex); |