diff options
author | Mike Kelly <mike.kelly@arm.com> | 2020-03-02 11:41:31 +0000 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2020-03-03 10:40:38 +0000 |
commit | 08759e26972290e6e3ac16289be24984999c70a4 (patch) | |
tree | 34f65e34a028b1150299631a080b20b0fe5a267b /src/armnnTfParser/TfParser.cpp | |
parent | b015e5db20e29825caefa05d828fbeed73119b19 (diff) | |
download | armnn-08759e26972290e6e3ac16289be24984999c70a4.tar.gz |
IVGCVSW-4375 Add parser support for Transpose
* Changed TfParser::ParseTranspose to use Transpose instead of Permute
* Changed TfLiteParser::ParseTranspose to use Transpose instead of Permute
!armnn:2787
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: If48f2fb88d97d31d66b6b1e631b41637d8e4c8f0
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x | src/armnnTfParser/TfParser.cpp | 16 |
1 files changed, 5 insertions, 11 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 124c5fdcc7..13833314fd 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -10,6 +10,7 @@ #include <armnnUtils/Permute.hpp> #include <armnnUtils/DataLayoutIndexed.hpp> +#include <armnnUtils/Transpose.hpp> #include <GraphTopologicalSort.hpp> #include <ParserHelper.hpp> @@ -2084,26 +2085,19 @@ ParsedTfOperationPtr TfParser::ParseTranspose(const tensorflow::NodeDef& nodeDef std::vector<int32_t> permuteVectorData; permuteVectorInput->GetConstTensor(permuteVectorData); - std::vector<unsigned int> armnnPermuteVectorData(permuteVectorData.size()); - std::vector<int32_t>::iterator it; - - for (unsigned int i = 0u; i < permuteVectorData.size(); ++i) - { - it = std::find(permuteVectorData.begin(), permuteVectorData.end(), i); - armnnPermuteVectorData[i] = static_cast<unsigned int>(std::distance(permuteVectorData.begin(), it)); - } + std::vector<unsigned int> armnnPermuteVectorData(permuteVectorData.begin(), permuteVectorData.end()); const auto permutationVector = PermutationVector(armnnPermuteVectorData.data(), permuteVectorInfo.GetNumElements()); - const auto desc = PermuteDescriptor(permutationVector); + const auto desc = TransposeDescriptor(permutationVector); - auto* layer = m_Network->AddPermuteLayer(desc, nodeDef.name().c_str()); + auto* layer = m_Network->AddTransposeLayer(desc, nodeDef.name().c_str()); BOOST_ASSERT(layer); input0Slot->Connect(layer->GetInputSlot(0)); const auto& input0Info = input0Slot->GetTensorInfo(); armnn::TensorInfo outputInfo {input0Info}; - outputInfo.SetShape(armnnUtils::Permuted(input0Info.GetShape(), desc.m_DimMappings)); + outputInfo.SetShape(armnnUtils::TransposeTensorShape(input0Info.GetShape(), desc.m_DimMappings)); layer->GetOutputSlot(0).SetTensorInfo(outputInfo); return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); |