diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rw-r--r-- | src/armnnTfParser/TfParser.cpp | 94 |
1 files changed, 46 insertions, 48 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index c00722c4ad..d5372a598b 100644 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -2304,68 +2304,73 @@ ParsedTfOperationPtr TfParser::ParsePooling2d(const tensorflow::NodeDef& nodeDef std::vector<uint32_t> ksize = ReadMandatoryNodeUint32ListAttribute(nodeDef, "ksize"); // size of pool windows Pooling2dDescriptor pooling2dDescriptor; - pooling2dDescriptor.m_PoolType = pooltype; - pooling2dDescriptor.m_PaddingMethod = PaddingMethod::Exclude; + pooling2dDescriptor.m_PoolType = pooltype; + pooling2dDescriptor.m_PaddingMethod = PaddingMethod::Exclude; pooling2dDescriptor.m_OutputShapeRounding = OutputShapeRounding::Floor; CHECK_DATA_FORMAT(nodeDef, dataFormat, "Pooling2D"); + DataLayout dataLayout = dataFormat == "NHWC" ? DataLayout::NHWC : DataLayout::NCHW; + pooling2dDescriptor.m_DataLayout = dataLayout; + DataLayoutIndexed dataLayoutIndexed(dataLayout); - if (dataFormat == "NHWC") - { - pooling2dDescriptor.m_StrideX = strides[2]; - pooling2dDescriptor.m_StrideY = strides[1]; - pooling2dDescriptor.m_PoolWidth = ksize[2]; - pooling2dDescriptor.m_PoolHeight = ksize[1]; - // Swizzles input to supported memory layout. - inputTensorInfo = armnnUtils::Permuted(inputSlot.GetTensorInfo(), NHWCToArmNN); - } - else if (dataFormat == "NCHW") - { - pooling2dDescriptor.m_StrideX = strides[3]; - pooling2dDescriptor.m_StrideY = strides[2]; - pooling2dDescriptor.m_PoolWidth = ksize[3]; - pooling2dDescriptor.m_PoolHeight = ksize[2]; - } + pooling2dDescriptor.m_StrideX = strides[dataLayoutIndexed.GetWidthIndex()]; + pooling2dDescriptor.m_StrideY = strides[dataLayoutIndexed.GetHeightIndex()]; + pooling2dDescriptor.m_PoolWidth = ksize[dataLayoutIndexed.GetWidthIndex()]; + pooling2dDescriptor.m_PoolHeight = ksize[dataLayoutIndexed.GetHeightIndex()]; - uint32_t inputHeight = inputTensorInfo.GetShape()[2]; - uint32_t inputWidth = inputTensorInfo.GetShape()[3]; + uint32_t inputHeight = inputTensorInfo.GetShape()[dataLayoutIndexed.GetHeightIndex()]; + uint32_t inputWidth = inputTensorInfo.GetShape()[dataLayoutIndexed.GetWidthIndex()]; bool padding = false; TensorInfo outputInfo; + unsigned int outputHeight = 0; + unsigned int outputWidth = 0; CHECK_PADDING_TYPE(nodeDef, paddingString); if (paddingString == "SAME") { padding = true; - outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], - inputTensorInfo.GetShape()[1], - static_cast<uint32_t>(ceil( - static_cast<float>(inputHeight) / - static_cast<float>(pooling2dDescriptor.m_StrideY))), - static_cast<uint32_t>(ceil( - static_cast<float>(inputWidth) / - static_cast<float>(pooling2dDescriptor.m_StrideX))) - }, DataType::Float32); + + outputHeight = static_cast<uint32_t>(ceil(static_cast<float>(inputHeight) / + static_cast<float>(pooling2dDescriptor.m_StrideY))); + outputWidth = static_cast<uint32_t>(ceil(static_cast<float>(inputWidth) / + static_cast<float>(pooling2dDescriptor.m_StrideX))); } else if (paddingString == "VALID") { padding = false; - outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], - inputTensorInfo.GetShape()[1], - static_cast<uint32_t>(ceil( - static_cast<float>(inputHeight - pooling2dDescriptor.m_PoolHeight + 1) / - static_cast<float>(pooling2dDescriptor.m_StrideY))), - static_cast<uint32_t>(ceil( - static_cast<float>(inputWidth - pooling2dDescriptor.m_PoolWidth + 1) / - static_cast<float>(pooling2dDescriptor.m_StrideX))) - }, DataType::Float32); + + outputHeight = static_cast<uint32_t>(ceil( + static_cast<float>(inputHeight - pooling2dDescriptor.m_PoolHeight + 1) / + static_cast<float>(pooling2dDescriptor.m_StrideY))); + outputWidth = static_cast<uint32_t>(ceil( + static_cast<float>(inputWidth - pooling2dDescriptor.m_PoolWidth + 1) / + static_cast<float>(pooling2dDescriptor.m_StrideX))); + } + + switch (dataLayout) + { + case DataLayout::NHWC: + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + outputHeight, + outputWidth, + inputTensorInfo.GetShape()[3] }, + DataType::Float32); + break; + case DataLayout::NCHW: + outputInfo = TensorInfo({ inputTensorInfo.GetShape()[0], + inputTensorInfo.GetShape()[1], + outputHeight, + outputWidth }, + DataType::Float32); + break; } CalcPadding(inputWidth, pooling2dDescriptor.m_PoolWidth, pooling2dDescriptor.m_StrideX, - pooling2dDescriptor.m_PadLeft, pooling2dDescriptor.m_PadRight, padding); + pooling2dDescriptor.m_PadLeft, pooling2dDescriptor.m_PadRight, padding); CalcPadding(inputHeight, pooling2dDescriptor.m_PoolHeight, pooling2dDescriptor.m_StrideY, - pooling2dDescriptor.m_PadTop, pooling2dDescriptor.m_PadBottom, padding); + pooling2dDescriptor.m_PadTop, pooling2dDescriptor.m_PadBottom, padding); IConnectableLayer* layer = m_Network->AddPooling2dLayer(pooling2dDescriptor, nodeDef.name().c_str()); @@ -2381,14 +2386,7 @@ ParsedTfOperationPtr TfParser::ParsePooling2d(const tensorflow::NodeDef& nodeDef layer->GetOutputSlot(0).SetTensorInfo(outputInfo); - if (dataFormat == "NHWC") - { - layer = SwizzleInDeswizzleOut(*m_Network, inputSlot, *layer, nodeDef.name()); - } - else - { - inputSlot.Connect(layer->GetInputSlot(0)); - } + inputSlot.Connect(layer->GetInputSlot(0)); return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); } |