aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializeParser/DeserializeParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnDeserializeParser/DeserializeParser.cpp')
-rw-r--r--src/armnnDeserializeParser/DeserializeParser.cpp125
1 files changed, 120 insertions, 5 deletions
diff --git a/src/armnnDeserializeParser/DeserializeParser.cpp b/src/armnnDeserializeParser/DeserializeParser.cpp
index eb7bccaa1d..f47c23f0b5 100644
--- a/src/armnnDeserializeParser/DeserializeParser.cpp
+++ b/src/armnnDeserializeParser/DeserializeParser.cpp
@@ -136,6 +136,7 @@ m_ParserFunctions(Layer_MAX+1, &DeserializeParser::ParseUnsupportedLayer)
// register supported layers
m_ParserFunctions[Layer_AdditionLayer] = &DeserializeParser::ParseAdd;
m_ParserFunctions[Layer_MultiplicationLayer] = &DeserializeParser::ParseMultiplication;
+ m_ParserFunctions[Layer_Pooling2dLayer] = &DeserializeParser::ParsePooling2d;
m_ParserFunctions[Layer_SoftmaxLayer] = &DeserializeParser::ParseSoftmax;
}
@@ -153,6 +154,8 @@ DeserializeParser::LayerBaseRawPtr DeserializeParser::GetBaseLayer(const GraphPt
return graphPtr->layers()->Get(layerIndex)->layer_as_MultiplicationLayer()->base();
case Layer::Layer_OutputLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_OutputLayer()->base()->base();
+ case Layer::Layer_Pooling2dLayer:
+ return graphPtr->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->base();
case Layer::Layer_SoftmaxLayer:
return graphPtr->layers()->Get(layerIndex)->layer_as_SoftmaxLayer()->base();
case Layer::Layer_NONE:
@@ -356,7 +359,6 @@ DeserializeParser::GraphPtr DeserializeParser::LoadGraphFromFile(const char* fil
}
std::ifstream file(fileName, std::ios::binary);
fileContent = std::string((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
-
return LoadGraphFromBinary(reinterpret_cast<const uint8_t*>(fileContent.c_str()), fileContent.size());
}
@@ -581,8 +583,8 @@ void DeserializeParser::ParseAdd(unsigned int layerIndex)
auto outputs = GetOutputs(m_Graph, layerIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
- auto layerName = boost::str(boost::format("Addition:%1%") % layerIndex);
- IConnectableLayer* layer = m_Network->AddAdditionLayer(layerName.c_str());
+ m_layerName = boost::str(boost::format("Addition:%1%") % layerIndex);
+ IConnectableLayer* layer = m_Network->AddAdditionLayer(m_layerName.c_str());
armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
@@ -601,8 +603,8 @@ void DeserializeParser::ParseMultiplication(unsigned int layerIndex)
auto outputs = GetOutputs(m_Graph, layerIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
- auto layerName = boost::str(boost::format("Multiplication:%1%") % layerIndex);
- IConnectableLayer* layer = m_Network->AddMultiplicationLayer(layerName.c_str());
+ m_layerName = boost::str(boost::format("Multiplication:%1%") % layerIndex);
+ IConnectableLayer* layer = m_Network->AddMultiplicationLayer(m_layerName.c_str());
armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
@@ -611,6 +613,119 @@ void DeserializeParser::ParseMultiplication(unsigned int layerIndex)
RegisterOutputSlots(layerIndex, layer);
}
+armnn::Pooling2dDescriptor DeserializeParser::GetPoolingDescriptor(DeserializeParser::PoolingDescriptor pooling2dDesc,
+ unsigned int layerIndex)
+{
+ armnn::Pooling2dDescriptor desc;
+
+ switch (pooling2dDesc->poolType())
+ {
+ case PoolingAlgorithm_Average:
+ {
+ desc.m_PoolType = armnn::PoolingAlgorithm::Average;
+ m_layerName = boost::str(boost::format("AveragePool2D:%1%") % layerIndex);
+ break;
+ }
+ case PoolingAlgorithm_Max:
+ {
+ desc.m_PoolType = armnn::PoolingAlgorithm::Max;
+ m_layerName = boost::str(boost::format("MaxPool2D:%1%") % layerIndex);
+ break;
+ }
+ default:
+ {
+ BOOST_ASSERT_MSG(false, "Unsupported pooling algorithm");
+ }
+ }
+
+ switch (pooling2dDesc->outputShapeRounding())
+ {
+ case OutputShapeRounding_Floor:
+ {
+ desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Floor;
+ break;
+ }
+ case OutputShapeRounding_Ceiling:
+ {
+ desc.m_OutputShapeRounding = armnn::OutputShapeRounding::Ceiling;
+ break;
+ }
+ default:
+ {
+ BOOST_ASSERT_MSG(false, "Unsupported output shape rounding");
+ }
+ }
+
+ switch (pooling2dDesc->paddingMethod())
+ {
+ case PaddingMethod_Exclude:
+ {
+ desc.m_PaddingMethod = armnn::PaddingMethod::Exclude;
+ break;
+ }
+ case PaddingMethod_IgnoreValue:
+ {
+ desc.m_PaddingMethod = armnn::PaddingMethod::IgnoreValue;
+ break;
+ }
+ default:
+ {
+ BOOST_ASSERT_MSG(false, "Unsupported padding method");
+ }
+ }
+
+ switch (pooling2dDesc->dataLayout())
+ {
+ case DataLayout_NCHW:
+ {
+ desc.m_DataLayout = armnn::DataLayout::NCHW;
+ break;
+ }
+ case DataLayout_NHWC:
+ {
+ desc.m_DataLayout = armnn::DataLayout::NHWC;
+ break;
+ }
+ default:
+ {
+ BOOST_ASSERT_MSG(false, "Unsupported data layout");
+ }
+ }
+
+ desc.m_PadRight = pooling2dDesc->padRight();
+ desc.m_PadLeft = pooling2dDesc->padLeft();
+ desc.m_PadBottom = pooling2dDesc->padBottom();
+ desc.m_PadTop = pooling2dDesc->padTop();
+ desc.m_StrideX = pooling2dDesc->strideX();
+ desc.m_StrideY = pooling2dDesc->strideY();
+ desc.m_PoolWidth = pooling2dDesc->poolWidth();
+ desc.m_PoolHeight = pooling2dDesc->poolHeight();
+
+ return desc;
+}
+
+void DeserializeParser::ParsePooling2d(unsigned int layerIndex)
+{
+ CHECK_LAYERS(m_Graph, 0, layerIndex);
+
+ auto pooling2dDes = m_Graph->layers()->Get(layerIndex)->layer_as_Pooling2dLayer()->descriptor();
+
+ auto inputs = GetInputs(m_Graph, layerIndex);
+ CHECK_VALID_SIZE(inputs.size(), 1);
+
+ auto outputs = GetOutputs(m_Graph, layerIndex);
+ CHECK_VALID_SIZE(outputs.size(), 1);
+ auto outputInfo = ToTensorInfo(outputs[0]);
+
+ auto pooling2dDescriptor = GetPoolingDescriptor(pooling2dDes, layerIndex);
+
+ IConnectableLayer* layer = m_Network->AddPooling2dLayer(pooling2dDescriptor, m_layerName.c_str());
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ RegisterInputSlots(layerIndex, layer);
+ RegisterOutputSlots(layerIndex, layer);
+}
+
void DeserializeParser::ParseSoftmax(unsigned int layerIndex)
{
CHECK_LAYERS(m_Graph, 0, layerIndex);