aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2020-03-02 11:41:31 +0000
committermike.kelly <mike.kelly@arm.com>2020-03-03 10:40:38 +0000
commit08759e26972290e6e3ac16289be24984999c70a4 (patch)
tree34f65e34a028b1150299631a080b20b0fe5a267b /src/armnnTfLiteParser
parentb015e5db20e29825caefa05d828fbeed73119b19 (diff)
downloadarmnn-08759e26972290e6e3ac16289be24984999c70a4.tar.gz
IVGCVSW-4375 Add parser support for Transpose
* Changed TfParser::ParseTranspose to use Transpose instead of Permute * Changed TfLiteParser::ParseTranspose to use Transpose instead of Permute !armnn:2787 Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: If48f2fb88d97d31d66b6b1e631b41637d8e4c8f0
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp19
1 files changed, 4 insertions, 15 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index f5c01f249a..56b59a115f 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1011,7 +1011,7 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
armnn::IConnectableLayer* layer = nullptr;
auto layerName = boost::str(boost::format("Transpose:%1%:%2%") % subgraphIndex % operatorIndex);
- PermuteDescriptor desc;
+ TransposeDescriptor desc;
if (inputs.size() == 2)
{
@@ -1020,23 +1020,12 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
auto numPermVecElements = permuteTensorInfo.GetNumElements();
std::vector<unsigned int> permuteShape(numPermVecElements);
::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes());
+ PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements());
- // permuteShape assumes Tf/Np permute vectors, we must translate to armnn expected form
- // to do so we find the perm vector which would invert what a tf perm vector would do (ex 3,0,1,2 -> 1,2,3,0)
- std::vector<unsigned int> armnnPermuteShape(numPermVecElements);
- std::vector<unsigned int>::iterator it;
- for (unsigned int i = 0u; i < numPermVecElements; ++i)
- {
- it = std::find(permuteShape.begin(), permuteShape.end(), i);
- armnnPermuteShape[i] = static_cast<unsigned int>(std::distance(permuteShape.begin(), it));
- }
-
- PermutationVector permutationVector(armnnPermuteShape.data(), permuteTensorInfo.GetNumElements());
-
- desc = PermuteDescriptor(permutationVector);
+ desc = TransposeDescriptor(permutationVector);
}
- layer = m_Network->AddPermuteLayer(desc, layerName.c_str());
+ layer = m_Network->AddTransposeLayer(desc, layerName.c_str());
BOOST_ASSERT(layer != nullptr);