diff options
author | Kevin May <kevin.may@arm.com> | 2019-09-27 17:21:06 +0100 |
---|---|---|
committer | Kevin May <kevin.may@arm.com> | 2019-09-27 17:21:06 +0100 |
commit | 85d9260b769bdad8ffde37546837cc206ac8ee14 (patch) | |
tree | 1d38d73821f6ab84331e1471d79c6f462f50dcc3 /src/armnnTfLiteParser/TfLiteParser.cpp | |
parent | dfa1477f8144c8ae933f949f0cc6ab70b6ba372d (diff) | |
download | armnn-85d9260b769bdad8ffde37546837cc206ac8ee14.tar.gz |
IVGCVSW-3909 Fix Transpose perm vector not parsed by Tflite parser
* Add permute vector to descriptor if present
* Refactor test to check with and without permute vector
Signed-off-by: Kevin May <kevin.may@arm.com>
Change-Id: Ic8d882bb0f982fd00bb2854c18ea316b1b2cde2b
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 939640a5e3..da81c0a628 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -871,7 +871,7 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex) CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); - CHECK_VALID_SIZE(inputs.size(), 2); + CHECK_VALID_SIZE(inputs.size(), 1, 2); auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); @@ -881,6 +881,19 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex) PermuteDescriptor desc; + if(inputs.size() == 2) + { + armnn::TensorInfo permuteTensorInfo = ToTensorInfo(inputs[1]); + BufferRawPtr permuteBufferPtr = GetBuffer(m_Model, inputs[1]->buffer); + + std::vector<unsigned int> permuteShape(permuteTensorInfo.GetNumElements()); + ::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes()); + + PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements()); + + desc = PermuteDescriptor(permutationVector); + } + layer = m_Network->AddPermuteLayer(desc, layerName.c_str()); BOOST_ASSERT(layer != nullptr); |