aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2019-09-27 17:21:06 +0100
committerKevin May <kevin.may@arm.com>2019-09-27 17:21:06 +0100
commit85d9260b769bdad8ffde37546837cc206ac8ee14 (patch)
tree1d38d73821f6ab84331e1471d79c6f462f50dcc3 /src/armnnTfLiteParser
parentdfa1477f8144c8ae933f949f0cc6ab70b6ba372d (diff)
downloadarmnn-85d9260b769bdad8ffde37546837cc206ac8ee14.tar.gz
IVGCVSW-3909 Fix Transpose perm vector not parsed by Tflite parser
* Add permute vector to descriptor if present * Refactor test to check with and without permute vector Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: Ic8d882bb0f982fd00bb2854c18ea316b1b2cde2b
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp15
-rw-r--r--src/armnnTfLiteParser/test/Transpose.cpp93
2 files changed, 76 insertions, 32 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 939640a5e3..da81c0a628 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -871,7 +871,7 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
- CHECK_VALID_SIZE(inputs.size(), 2);
+ CHECK_VALID_SIZE(inputs.size(), 1, 2);
auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
CHECK_VALID_SIZE(outputs.size(), 1);
@@ -881,6 +881,19 @@ void TfLiteParser::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
PermuteDescriptor desc;
+ if(inputs.size() == 2)
+ {
+ armnn::TensorInfo permuteTensorInfo = ToTensorInfo(inputs[1]);
+ BufferRawPtr permuteBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
+
+ std::vector<unsigned int> permuteShape(permuteTensorInfo.GetNumElements());
+ ::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes());
+
+ PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements());
+
+ desc = PermuteDescriptor(permutationVector);
+ }
+
layer = m_Network->AddPermuteLayer(desc, layerName.c_str());
BOOST_ASSERT(layer != nullptr);
diff --git a/src/armnnTfLiteParser/test/Transpose.cpp b/src/armnnTfLiteParser/test/Transpose.cpp
index 4430438bb9..2e3190b62e 100644
--- a/src/armnnTfLiteParser/test/Transpose.cpp
+++ b/src/armnnTfLiteParser/test/Transpose.cpp
@@ -12,6 +12,7 @@ BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
struct TransposeFixture : public ParserFlatbuffersFixture
{
explicit TransposeFixture(const std::string & inputShape,
+ const std::string & permuteData,
const std::string & outputShape)
{
m_JsonString = R"(
@@ -29,8 +30,8 @@ struct TransposeFixture : public ParserFlatbuffersFixture
{
"shape": )" + inputShape + R"(,
"type": "FLOAT32",
- "buffer": 3,
- "name": "Placeholder",
+ "buffer": 0,
+ "name": "inputTensor",
"quantization": {
"min": [
0.0
@@ -46,28 +47,33 @@ struct TransposeFixture : public ParserFlatbuffersFixture
{
"shape": )" + outputShape + R"(,
"type": "FLOAT32",
- "buffer": 2,
- "name": "transpose",
- "quantization": {
- "details_type": 0,
- "quantized_dimension": 0
- },
- "is_variable": false
- },
- {
- "shape": [
- 3
- ],
- "type": "INT32",
"buffer": 1,
- "name": "transpose/perm",
+ "name": "outputTensor",
"quantization": {
"details_type": 0,
"quantized_dimension": 0
},
"is_variable": false
- }
- ],
+ })";
+ if (!permuteData.empty())
+ {
+ m_JsonString += R"(,
+ {
+ "shape": [
+ 3
+ ],
+ "type": "INT32",
+ "buffer": 2,
+ "name": "permuteTensor",
+ "quantization": {
+ "details_type": 0,
+ "quantized_dimension": 0
+ },
+ "is_variable": false
+ })";
+ }
+
+ m_JsonString += R"(],
"inputs": [
0
],
@@ -78,9 +84,12 @@ struct TransposeFixture : public ParserFlatbuffersFixture
{
"opcode_index": 0,
"inputs": [
- 0,
- 2
- ],
+ 0)";
+ if (!permuteData.empty())
+ {
+ m_JsonString += R"(,2)";
+ }
+ m_JsonString += R"(],
"outputs": [
1
],
@@ -95,9 +104,12 @@ struct TransposeFixture : public ParserFlatbuffersFixture
"description": "TOCO Converted.",
"buffers": [
{ },
- { },
- { },
- { }
+ { })";
+ if (!permuteData.empty())
+ {
+ m_JsonString += R"(,{"data": )" + permuteData + R"( })";
+ }
+ m_JsonString += R"(
]
}
)";
@@ -105,20 +117,39 @@ struct TransposeFixture : public ParserFlatbuffersFixture
}
};
-struct SimpleTransposeFixture : TransposeFixture
+struct TransposeFixtureWithPermuteData : TransposeFixture
{
- SimpleTransposeFixture() : TransposeFixture("[ 2, 2, 3 ]",
- "[ 2, 3, 2 ]") {}
+ TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
+ "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]",
+ "[ 2, 3, 2 ]") {}
};
-BOOST_FIXTURE_TEST_CASE(SimpleTranspose, SimpleTransposeFixture)
+BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData)
{
RunTest<3, armnn::DataType::Float32>(
0,
- {{"Placeholder", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+ {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+ {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
+
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
+ == armnn::TensorShape({2,3,2})));
+}
+
+struct TransposeFixtureWithoutPermuteData : TransposeFixture
+{
+ TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
+ "",
+ "[ 2, 3, 2 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteDims, TransposeFixtureWithoutPermuteData)
+{
+ RunTest<3, armnn::DataType::Float32>(
+ 0,
+ {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+ {{"outputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}});
- {{"transpose", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}});
- BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "transpose").second.GetShape()
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
== armnn::TensorShape({2,3,2})));
}