diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 19 | ||||
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 16 |
2 files changed, 9 insertions, 26 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index f5c01f249a..56b59a115f 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -1011,7 +1011,7 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex) armnn::IConnectableLayer* layer = nullptr; auto layerName = boost::str(boost::format("Transpose:%1%:%2%") % subgraphIndex % operatorIndex); - PermuteDescriptor desc; + TransposeDescriptor desc; if (inputs.size() == 2) { @@ -1020,23 +1020,12 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex) auto numPermVecElements = permuteTensorInfo.GetNumElements(); std::vector<unsigned int> permuteShape(numPermVecElements); ::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes()); + PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements()); - // permuteShape assumes Tf/Np permute vectors, we must translate to armnn expected form - // to do so we find the perm vector which would invert what a tf perm vector would do (ex 3,0,1,2 -> 1,2,3,0) - std::vector<unsigned int> armnnPermuteShape(numPermVecElements); - std::vector<unsigned int>::iterator it; - for (unsigned int i = 0u; i < numPermVecElements; ++i) - { - it = std::find(permuteShape.begin(), permuteShape.end(), i); - armnnPermuteShape[i] = static_cast<unsigned int>(std::distance(permuteShape.begin(), it)); - } - - PermutationVector permutationVector(armnnPermuteShape.data(), permuteTensorInfo.GetNumElements()); - - desc = PermuteDescriptor(permutationVector); + desc = TransposeDescriptor(permutationVector); } - layer = m_Network->AddPermuteLayer(desc, layerName.c_str()); + layer = m_Network->AddTransposeLayer(desc, layerName.c_str()); BOOST_ASSERT(layer != nullptr); diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 124c5fdcc7..13833314fd 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -10,6 +10,7 @@ #include <armnnUtils/Permute.hpp> #include <armnnUtils/DataLayoutIndexed.hpp> +#include <armnnUtils/Transpose.hpp> #include <GraphTopologicalSort.hpp> #include <ParserHelper.hpp> @@ -2084,26 +2085,19 @@ ParsedTfOperationPtr TfParser::ParseTranspose(const tensorflow::NodeDef& nodeDef std::vector<int32_t> permuteVectorData; permuteVectorInput->GetConstTensor(permuteVectorData); - std::vector<unsigned int> armnnPermuteVectorData(permuteVectorData.size()); - std::vector<int32_t>::iterator it; - - for (unsigned int i = 0u; i < permuteVectorData.size(); ++i) - { - it = std::find(permuteVectorData.begin(), permuteVectorData.end(), i); - armnnPermuteVectorData[i] = static_cast<unsigned int>(std::distance(permuteVectorData.begin(), it)); - } + std::vector<unsigned int> armnnPermuteVectorData(permuteVectorData.begin(), permuteVectorData.end()); const auto permutationVector = PermutationVector(armnnPermuteVectorData.data(), permuteVectorInfo.GetNumElements()); - const auto desc = PermuteDescriptor(permutationVector); + const auto desc = TransposeDescriptor(permutationVector); - auto* layer = m_Network->AddPermuteLayer(desc, nodeDef.name().c_str()); + auto* layer = m_Network->AddTransposeLayer(desc, nodeDef.name().c_str()); BOOST_ASSERT(layer); input0Slot->Connect(layer->GetInputSlot(0)); const auto& input0Info = input0Slot->GetTensorInfo(); armnn::TensorInfo outputInfo {input0Info}; - outputInfo.SetShape(armnnUtils::Permuted(input0Info.GetShape(), desc.m_DimMappings)); + outputInfo.SetShape(armnnUtils::TransposeTensorShape(input0Info.GetShape(), desc.m_DimMappings)); layer->GetOutputSlot(0).SetTensorInfo(outputInfo); return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); |