From 08759e26972290e6e3ac16289be24984999c70a4 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Mon, 2 Mar 2020 11:41:31 +0000 Subject: IVGCVSW-4375 Add parser support for Transpose * Changed TfParser::ParseTranspose to use Transpose instead of Permute * Changed TfLiteParser::ParseTranspose to use Transpose instead of Permute !armnn:2787 Signed-off-by: Mike Kelly Change-Id: If48f2fb88d97d31d66b6b1e631b41637d8e4c8f0 --- src/armnnTfLiteParser/TfLiteParser.cpp | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) (limited to 'src/armnnTfLiteParser') diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index f5c01f249a..56b59a115f 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1011,7 +1011,7 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex) armnn::IConnectableLayer* layer = nullptr; auto layerName = boost::str(boost::format("Transpose:%1%:%2%") % subgraphIndex % operatorIndex); - PermuteDescriptor desc; + TransposeDescriptor desc; if (inputs.size() == 2) { @@ -1020,23 +1020,12 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex) auto numPermVecElements = permuteTensorInfo.GetNumElements(); std::vector permuteShape(numPermVecElements); ::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes()); + PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements()); - // permuteShape assumes Tf/Np permute vectors, we must translate to armnn expected form - // to do so we find the perm vector which would invert what a tf perm vector would do (ex 3,0,1,2 -> 1,2,3,0) - std::vector armnnPermuteShape(numPermVecElements); - std::vector::iterator it; - for (unsigned int i = 0u; i < numPermVecElements; ++i) - { - it = std::find(permuteShape.begin(), permuteShape.end(), i); - armnnPermuteShape[i] = static_cast(std::distance(permuteShape.begin(), it)); - } - - PermutationVector permutationVector(armnnPermuteShape.data(), permuteTensorInfo.GetNumElements()); - - desc = PermuteDescriptor(permutationVector); + desc = TransposeDescriptor(permutationVector); } - layer = m_Network->AddPermuteLayer(desc, layerName.c_str()); + layer = m_Network->AddTransposeLayer(desc, layerName.c_str()); BOOST_ASSERT(layer != nullptr); -- cgit v1.2.1