aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2019-09-27 17:21:06 +0100
committerKevin May <kevin.may@arm.com>2019-09-27 17:21:06 +0100
commit85d9260b769bdad8ffde37546837cc206ac8ee14 (patch)
tree1d38d73821f6ab84331e1471d79c6f462f50dcc3 /src/armnnTfLiteParser/TfLiteParser.cpp
parentdfa1477f8144c8ae933f949f0cc6ab70b6ba372d (diff)
downloadarmnn-85d9260b769bdad8ffde37546837cc206ac8ee14.tar.gz
IVGCVSW-3909 Fix Transpose perm vector not parsed by Tflite parser
* Add permute vector to descriptor if present * Refactor test to check with and without permute vector Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: Ic8d882bb0f982fd00bb2854c18ea316b1b2cde2b
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp15
1 files changed, 14 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 939640a5e3..da81c0a628 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -871,7 +871,7 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
- CHECK_VALID_SIZE(inputs.size(), 2);
+ CHECK_VALID_SIZE(inputs.size(), 1, 2);
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
@@ -881,6 +881,19 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
PermuteDescriptor desc;
+ if(inputs.size() == 2)
+ {
+ armnn::TensorInfo permuteTensorInfo = ToTensorInfo(inputs[1]);
+ BufferRawPtr permuteBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+
+ std::vector<unsigned int> permuteShape(permuteTensorInfo.GetNumElements());
+ ::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes());
+
+ PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements());
+
+ desc = PermuteDescriptor(permutationVector);
+ }
+
layer = m_Network->AddPermuteLayer(desc, layerName.c_str());
BOOST_ASSERT(layer != nullptr);