diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 939640a5e3..da81c0a628 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -871,7 +871,7 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex) CHECK_MODEL(m_Model, subgraphIndex, operatorIndex); auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex); - CHECK_VALID_SIZE(inputs.size(), 2); + CHECK_VALID_SIZE(inputs.size(), 1, 2); auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex); CHECK_VALID_SIZE(outputs.size(), 1); @@ -881,6 +881,19 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex) PermuteDescriptor desc; + if(inputs.size() == 2) + { + armnn::TensorInfo permuteTensorInfo = ToTensorInfo(inputs[1]); + BufferRawPtr permuteBufferPtr = GetBuffer(m_Model, inputs[1]->buffer); + + std::vector<unsigned int> permuteShape(permuteTensorInfo.GetNumElements()); + ::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes()); + + PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements()); + + desc = PermuteDescriptor(permutationVector); + } + layer = m_Network->AddPermuteLayer(desc, layerName.c_str()); BOOST_ASSERT(layer != nullptr); |