From aa920c56838c2a0b31bd4e3c54bd57ff2f20969e Mon Sep 17 00:00:00 2001 From: Tee Jung Date: Tue, 5 Nov 2019 10:48:25 +0000 Subject: Build graph->inputIds/outputIds with layerBindingId instead of layerIndex Signed-off-by: Jung Tae-young tee.ty.jung@openedges.com Signed-off-by: Matteo Martincigh Change-Id: I25ceeca70e72fad88ab039aed5a5ab6a7cc08c6c Signed-off-by: Derek Lamberti --- src/armnnDeserializer/Deserializer.cpp | 74 ++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 4 deletions(-) (limited to 'src/armnnDeserializer/Deserializer.cpp') diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 99ee0b5b2d..3bbd71a972 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -768,6 +768,40 @@ BindingPointInfo Deserializer::GetNetworkOutputBindingInfo(unsigned int layerInd CHECK_LOCATION().AsString())); } +unsigned int Deserializer::GetInputLayerInVector(GraphPtr graph, int targetId) +{ + for (unsigned int i = 0; i < graph->layers()->size(); i++) + { + auto layer = graph->layers()->Get(i); + if (layer->layer_type() == Layer::Layer_InputLayer) + { + auto layerBindingId = layer->layer_as_InputLayer()->base()->layerBindingId(); + if (layerBindingId == targetId) + { + return i; + } + } + } + throw ParseException("Input layer with given layerBindingId not found"); +} + +unsigned int Deserializer::GetOutputLayerInVector(GraphPtr graph, int targetId) +{ + for (unsigned int i = 0; i < graph->layers()->size(); i++) + { + auto layer = graph->layers()->Get(i); + if (layer->layer_type() == Layer::Layer_OutputLayer) + { + auto layerBindingId = layer->layer_as_OutputLayer()->base()->layerBindingId(); + if (layerBindingId == targetId) + { + return i; + } + } + } + throw ParseException("Output layer with given layerBindingId not found"); +} + unsigned int Deserializer::GetLayerIndexInVector(GraphPtr graph, unsigned int targetIndex) { for (unsigned int i = 0; i < graph->layers()->size(); i++) @@ -781,6 +815,18 @@ unsigned int Deserializer::GetLayerIndexInVector(GraphPtr graph, unsigned int ta throw ParseException("Layer with given index not found"); } +Deserializer::FeatureVersions Deserializer::GetFeatureVersions(GraphPtr graph) +{ + Deserializer::FeatureVersions versions; + + if (graph->featureVersions()) + { + versions.m_BindingIdScheme = graph->featureVersions()->bindingIdsScheme(); + } + + return versions; +} + void Deserializer::SetupInputLayers(GraphPtr graph) { CHECK_GRAPH(graph, 0); @@ -790,8 +836,18 @@ void Deserializer::SetupInputLayers(GraphPtr graph) for (unsigned int i = 0; i < numInputs; i++) { - const unsigned int inputId = graph->inputIds()->Get(i); - const unsigned int inputLayerIndex = GetLayerIndexInVector(graph, inputId); + unsigned int inputLayerIndex = 0xFFFFFFFF; + if (GetFeatureVersions(graph).m_BindingIdScheme == 0) + { + const unsigned int inputId = boost::numeric_cast(graph->inputIds()->Get(i)); + inputLayerIndex = GetLayerIndexInVector(graph, inputId); + } + else + { + const int inputId = graph->inputIds()->Get(i); + inputLayerIndex = GetInputLayerInVector(graph, inputId); + } + LayerBaseRawPtr baseLayer = GetBaseLayer(graph, inputLayerIndex); // GetBindingLayerInfo expect the index to be index in the vector not index property on each layer base @@ -819,8 +875,18 @@ void Deserializer::SetupOutputLayers(GraphPtr graph) for (unsigned int i = 0; i < numOutputs; i++) { - const unsigned int outputId = graph->outputIds()->Get(i); - const unsigned int outputLayerIndex = GetLayerIndexInVector(graph, outputId); + unsigned int outputLayerIndex = 0xFFFFFFFF; + if (GetFeatureVersions(graph).m_BindingIdScheme == 0) + { + const unsigned int outputId = boost::numeric_cast(graph->outputIds()->Get(i)); + outputLayerIndex = GetLayerIndexInVector(graph, outputId); + } + else + { + const int outputId = graph->outputIds()->Get(i); + outputLayerIndex = GetOutputLayerInVector(graph, outputId); + } + LayerBaseRawPtr baseLayer = GetBaseLayer(graph, outputLayerIndex); // GetBindingLayerInfo expect the index to be index in the vector not index property on each layer base -- cgit v1.2.1