From dd3f71b64072c44cec65a7a883d0c3a29659645c Mon Sep 17 00:00:00 2001 From: Sang-Hoon Park Date: Tue, 18 Feb 2020 11:27:35 +0000 Subject: COMPMID-3060: Add TF Parser support for Transpose Signed-off-by: Sang-Hoon Park Change-Id: I9661787071554b38c5b0ab3c98431f3863b98520 --- src/armnnTfParser/TfParser.cpp | 57 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) (limited to 'src/armnnTfParser/TfParser.cpp') diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index b5a421145a..124c5fdcc7 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -378,7 +378,8 @@ const std::map TfParser::ms_Ope { "Pad", &TfParser::ParsePad }, { "Sub", &TfParser::ParseSub }, { "Pack" , &TfParser::ParseStack }, - { "Stack", &TfParser::ParseStack } + { "Stack", &TfParser::ParseStack }, + { "Transpose", &TfParser::ParseTranspose }, }; const std::list TfParser::m_ControlInputs = { @@ -2054,6 +2055,60 @@ ParsedTfOperationPtr TfParser::ParseStack(const tensorflow::NodeDef& nodeDef, co return std::make_unique(this, nodeDef, layer); } +ParsedTfOperationPtr TfParser::ParseTranspose(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + auto inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + const auto inputCount = inputs.size(); + + if (inputCount != 2) + { + throw ParseException( + boost::str( + boost::format( + "The number of given input is %1%. It should be two for Transpose op." + "Node %2% %3%") + % inputCount + % nodeDef.name() + % CHECK_LOCATION().AsString())); + } + + auto* input0Slot = &inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + + const auto constInput = inputs[GetConstInputIndex(inputs)]; + auto* permuteVectorInput = + boost::polymorphic_downcast*>(constInput.m_IndexedValue); + const auto& permuteVectorInfo = permuteVectorInput->GetTensorInfo(); + + std::vector permuteVectorData; + permuteVectorInput->GetConstTensor(permuteVectorData); + + std::vector armnnPermuteVectorData(permuteVectorData.size()); + std::vector::iterator it; + + for (unsigned int i = 0u; i < permuteVectorData.size(); ++i) + { + it = std::find(permuteVectorData.begin(), permuteVectorData.end(), i); + armnnPermuteVectorData[i] = static_cast(std::distance(permuteVectorData.begin(), it)); + } + + const auto permutationVector = PermutationVector(armnnPermuteVectorData.data(), permuteVectorInfo.GetNumElements()); + const auto desc = PermuteDescriptor(permutationVector); + + auto* layer = m_Network->AddPermuteLayer(desc, nodeDef.name().c_str()); + BOOST_ASSERT(layer); + + input0Slot->Connect(layer->GetInputSlot(0)); + + const auto& input0Info = input0Slot->GetTensorInfo(); + armnn::TensorInfo outputInfo {input0Info}; + outputInfo.SetShape(armnnUtils::Permuted(input0Info.GetShape(), desc.m_DimMappings)); + layer->GetOutputSlot(0).SetTensorInfo(outputInfo); + + return std::make_unique(this, nodeDef, layer); +} + unsigned int CheckPaddingTensor(const ConstTensor& paddingTensor, const TensorInfo& inputTensorInfo, const std::string& nodeName) -- cgit v1.2.1