aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2020-03-02 11:41:31 +0000
committermike.kelly <mike.kelly@arm.com>2020-03-03 10:40:38 +0000
commit08759e26972290e6e3ac16289be24984999c70a4 (patch)
tree34f65e34a028b1150299631a080b20b0fe5a267b /src/armnnTfParser/TfParser.cpp
parentb015e5db20e29825caefa05d828fbeed73119b19 (diff)
downloadarmnn-08759e26972290e6e3ac16289be24984999c70a4.tar.gz
IVGCVSW-4375 Add parser support for Transpose
* Changed TfParser::ParseTranspose to use Transpose instead of Permute * Changed TfLiteParser::ParseTranspose to use Transpose instead of Permute !armnn:2787 Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: If48f2fb88d97d31d66b6b1e631b41637d8e4c8f0
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp16
1 files changed, 5 insertions, 11 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 124c5fdcc7..13833314fd 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -10,6 +10,7 @@
#include <armnnUtils/Permute.hpp>
#include <armnnUtils/DataLayoutIndexed.hpp>
+#include <armnnUtils/Transpose.hpp>
#include <GraphTopologicalSort.hpp>
#include <ParserHelper.hpp>
@@ -2084,26 +2085,19 @@ ParsedTfOperationPtr TfParser::ParseTranspose(const tensorflow::NodeDef& nodeDef
std::vector<int32_t> permuteVectorData;
permuteVectorInput->GetConstTensor(permuteVectorData);
- std::vector<unsigned int> armnnPermuteVectorData(permuteVectorData.size());
- std::vector<int32_t>::iterator it;
-
- for (unsigned int i = 0u; i < permuteVectorData.size(); ++i)
- {
- it = std::find(permuteVectorData.begin(), permuteVectorData.end(), i);
- armnnPermuteVectorData[i] = static_cast<unsigned int>(std::distance(permuteVectorData.begin(), it));
- }
+ std::vector<unsigned int> armnnPermuteVectorData(permuteVectorData.begin(), permuteVectorData.end());
const auto permutationVector = PermutationVector(armnnPermuteVectorData.data(), permuteVectorInfo.GetNumElements());
- const auto desc = PermuteDescriptor(permutationVector);
+ const auto desc = TransposeDescriptor(permutationVector);
- auto* layer = m_Network->AddPermuteLayer(desc, nodeDef.name().c_str());
+ auto* layer = m_Network->AddTransposeLayer(desc, nodeDef.name().c_str());
BOOST_ASSERT(layer);
input0Slot->Connect(layer->GetInputSlot(0));
const auto& input0Info = input0Slot->GetTensorInfo();
armnn::TensorInfo outputInfo {input0Info};
- outputInfo.SetShape(armnnUtils::Permuted(input0Info.GetShape(), desc.m_DimMappings));
+ outputInfo.SetShape(armnnUtils::TransposeTensorShape(input0Info.GetShape(), desc.m_DimMappings));
layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);