aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.cpp
diff options
context:
space:
mode:
authorJim Flynn <jim.flynn@arm.com>2019-05-22 14:24:13 +0100
committerJim Flynn <jim.flynn@arm.com>2019-05-28 17:50:33 +0100
commite242f2dc646f41e9162aaaf74e057ce39fcb92df (patch)
treed6c49b559c34d1d306b1e901501dded1c18f71c5 /src/armnnDeserializer/Deserializer.cpp
parent2f2778f36e59537bbd47fb8b21e73c6c5a949584 (diff)
downloadarmnn-e242f2dc646f41e9162aaaf74e057ce39fcb92df.tar.gz
IVGCVSW-3119 Rename MergerLayer to ConcatLayer
!android-nn-driver:1210 Change-Id: I940b3b9e421c92bfd55ae996f7bc54ac077f2604 Signed-off-by: Jim Flynn <jim.flynn@arm.com>
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r--src/armnnDeserializer/Deserializer.cpp30
1 files changed, 24 insertions, 6 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 14cf232cdb..75c258b7ab 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -192,6 +192,7 @@ m_ParserFunctions(Layer_MAX+1, &Deserializer::ParseUnsupportedLayer)
m_ParserFunctions[Layer_AdditionLayer] = &Deserializer::ParseAdd;
m_ParserFunctions[Layer_BatchToSpaceNdLayer] = &Deserializer::ParseBatchToSpaceNd;
m_ParserFunctions[Layer_BatchNormalizationLayer] = &Deserializer::ParseBatchNormalization;
+ m_ParserFunctions[Layer_ConcatLayer] = &Deserializer::ParseConcat;
m_ParserFunctions[Layer_ConstantLayer] = &Deserializer::ParseConstant;
m_ParserFunctions[Layer_Convolution2dLayer] = &Deserializer::ParseConvolution2d;
m_ParserFunctions[Layer_DepthwiseConvolution2dLayer] = &Deserializer::ParseDepthwiseConvolution2d;
@@ -241,6 +242,8 @@ Deserializer::LayerBaseRawPtr Deserializer::GetBaseLayer(const GraphPtr& graphPt
return graphPtr->layers()->Get(layerIndex)->layer_as_BatchToSpaceNdLayer()->base();
case Layer::Layer_BatchNormalizationLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_BatchNormalizationLayer()->base();
+ case Layer::Layer_ConcatLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_ConcatLayer()->base();
case Layer::Layer_ConstantLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_ConstantLayer()->base();
case Layer::Layer_Convolution2dLayer:
@@ -1229,6 +1232,22 @@ void Deserializer::ParseMaximum(GraphPtr graph, unsigned int layerIndex)
RegisterOutputSlots(graph, layerIndex, layer);
}
+const armnnSerializer::OriginsDescriptor* GetOriginsDescriptor(const armnnSerializer::SerializedGraph* graph,
+ unsigned int layerIndex)
+{
+ auto layerType = graph->layers()->Get(layerIndex)->layer_type();
+
+ switch (layerType)
+ {
+ case Layer::Layer_ConcatLayer:
+ return graph->layers()->Get(layerIndex)->layer_as_ConcatLayer()->descriptor();
+ case Layer::Layer_MergerLayer:
+ return graph->layers()->Get(layerIndex)->layer_as_MergerLayer()->descriptor();
+ default:
+ throw armnn::Exception("unknown layer type, should be concat or merger");
+ }
+}
+
void Deserializer::ParseConcat(GraphPtr graph, unsigned int layerIndex)
{
CHECK_LAYERS(graph, 0, layerIndex);
@@ -1237,18 +1256,17 @@ void Deserializer::ParseConcat(GraphPtr graph, unsigned int layerIndex)
auto outputs = GetOutputs(graph, layerIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
- auto mergerLayer = graph->layers()->Get(layerIndex)->layer_as_MergerLayer();
auto layerName = GetLayerName(graph, layerIndex);
- auto mergerDescriptor = mergerLayer->descriptor();
- unsigned int numViews = mergerDescriptor->numViews();
- unsigned int numDimensions = mergerDescriptor->numDimensions();
+ auto originsDescriptor = GetOriginsDescriptor(graph, layerIndex);
+ unsigned int numViews = originsDescriptor->numViews();
+ unsigned int numDimensions = originsDescriptor->numDimensions();
// can now check the number of inputs == number of views
auto inputs = GetInputs(graph, layerIndex);
CHECK_VALID_SIZE(inputs.size(), numViews);
armnn::OriginsDescriptor descriptor(numViews, numDimensions);
- auto originsPtr = mergerDescriptor->viewOrigins();
+ auto originsPtr = originsDescriptor->viewOrigins();
for (unsigned int v = 0; v < numViews; ++v)
{
auto originPtr = originsPtr->Get(v);
@@ -1258,7 +1276,7 @@ void Deserializer::ParseConcat(GraphPtr graph, unsigned int layerIndex)
descriptor.SetViewOriginCoord(v, d, value);
}
}
- descriptor.SetConcatAxis(mergerDescriptor->concatAxis());
+ descriptor.SetConcatAxis(originsDescriptor->concatAxis());
IConnectableLayer* layer = m_Network->AddConcatLayer(descriptor, layerName.c_str());
armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);