aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp15
1 files changed, 14 insertions, 1 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 939640a5e3..da81c0a628 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -871,7 +871,7 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
- CHECK_VALID_SIZE(inputs.size(), 2);
+ CHECK_VALID_SIZE(inputs.size(), 1, 2);
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
@@ -881,6 +881,19 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
PermuteDescriptor desc;
+ if(inputs.size() == 2)
+ {
+ armnn::TensorInfo permuteTensorInfo = ToTensorInfo(inputs[1]);
+ BufferRawPtr permuteBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+
+ std::vector<unsigned int> permuteShape(permuteTensorInfo.GetNumElements());
+ ::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes());
+
+ PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements());
+
+ desc = PermuteDescriptor(permutationVector);
+ }
+
layer = m_Network->AddPermuteLayer(desc, layerName.c_str());
BOOST_ASSERT(layer != nullptr);