aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test
diff options
context:
space:
mode:
authorKevin May <kevin.may@arm.com>2019-09-27 17:21:06 +0100
committerKevin May <kevin.may@arm.com>2019-09-27 17:21:06 +0100
commit85d9260b769bdad8ffde37546837cc206ac8ee14 (patch)
tree1d38d73821f6ab84331e1471d79c6f462f50dcc3 /src/armnnTfLiteParser/test
parentdfa1477f8144c8ae933f949f0cc6ab70b6ba372d (diff)
downloadarmnn-85d9260b769bdad8ffde37546837cc206ac8ee14.tar.gz
IVGCVSW-3909 Fix Transpose perm vector not parsed by Tflite parser
* Add permute vector to descriptor if present * Refactor test to check with and without permute vector Signed-off-by: Kevin May <kevin.may@arm.com> Change-Id: Ic8d882bb0f982fd00bb2854c18ea316b1b2cde2b
Diffstat (limited to 'src/armnnTfLiteParser/test')
-rw-r--r--src/armnnTfLiteParser/test/Transpose.cpp93
1 files changed, 62 insertions, 31 deletions
diff --git a/src/armnnTfLiteParser/test/Transpose.cpp b/src/armnnTfLiteParser/test/Transpose.cpp
index 4430438bb9..2e3190b62e 100644
--- a/src/armnnTfLiteParser/test/Transpose.cpp
+++ b/src/armnnTfLiteParser/test/Transpose.cpp
@@ -12,6 +12,7 @@ BOOST_AUTO_TEST_SUITE(TensorflowLiteParser)
struct TransposeFixture : public ParserFlatbuffersFixture
{
explicit TransposeFixture(const std::string & inputShape,
+ const std::string & permuteData,
const std::string & outputShape)
{
m_JsonString = R"(
@@ -29,8 +30,8 @@ struct TransposeFixture : public ParserFlatbuffersFixture
{
"shape": )" + inputShape + R"(,
"type": "FLOAT32",
- "buffer": 3,
- "name": "Placeholder",
+ "buffer": 0,
+ "name": "inputTensor",
"quantization": {
"min": [
0.0
@@ -46,28 +47,33 @@ struct TransposeFixture : public ParserFlatbuffersFixture
{
"shape": )" + outputShape + R"(,
"type": "FLOAT32",
- "buffer": 2,
- "name": "transpose",
- "quantization": {
- "details_type": 0,
- "quantized_dimension": 0
- },
- "is_variable": false
- },
- {
- "shape": [
- 3
- ],
- "type": "INT32",
"buffer": 1,
- "name": "transpose/perm",
+ "name": "outputTensor",
"quantization": {
"details_type": 0,
"quantized_dimension": 0
},
"is_variable": false
- }
- ],
+ })";
+ if (!permuteData.empty())
+ {
+ m_JsonString += R"(,
+ {
+ "shape": [
+ 3
+ ],
+ "type": "INT32",
+ "buffer": 2,
+ "name": "permuteTensor",
+ "quantization": {
+ "details_type": 0,
+ "quantized_dimension": 0
+ },
+ "is_variable": false
+ })";
+ }
+
+ m_JsonString += R"(],
"inputs": [
0
],
@@ -78,9 +84,12 @@ struct TransposeFixture : public ParserFlatbuffersFixture
{
"opcode_index": 0,
"inputs": [
- 0,
- 2
- ],
+ 0)";
+ if (!permuteData.empty())
+ {
+ m_JsonString += R"(,2)";
+ }
+ m_JsonString += R"(],
"outputs": [
1
],
@@ -95,9 +104,12 @@ struct TransposeFixture : public ParserFlatbuffersFixture
"description": "TOCO Converted.",
"buffers": [
{ },
- { },
- { },
- { }
+ { })";
+ if (!permuteData.empty())
+ {
+ m_JsonString += R"(,{"data": )" + permuteData + R"( })";
+ }
+ m_JsonString += R"(
]
}
)";
@@ -105,20 +117,39 @@ struct TransposeFixture : public ParserFlatbuffersFixture
}
};
-struct SimpleTransposeFixture : TransposeFixture
+struct TransposeFixtureWithPermuteData : TransposeFixture
{
- SimpleTransposeFixture() : TransposeFixture("[ 2, 2, 3 ]",
- "[ 2, 3, 2 ]") {}
+ TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
+ "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]",
+ "[ 2, 3, 2 ]") {}
};
-BOOST_FIXTURE_TEST_CASE(SimpleTranspose, SimpleTransposeFixture)
+BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData)
{
RunTest<3, armnn::DataType::Float32>(
0,
- {{"Placeholder", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+ {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+ {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
+
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
+ == armnn::TensorShape({2,3,2})));
+}
+
+struct TransposeFixtureWithoutPermuteData : TransposeFixture
+{
+ TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
+ "",
+ "[ 2, 3, 2 ]") {}
+};
+
+BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteDims, TransposeFixtureWithoutPermuteData)
+{
+ RunTest<3, armnn::DataType::Float32>(
+ 0,
+ {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+ {{"outputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}});
- {{"transpose", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}});
- BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "transpose").second.GetShape()
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
== armnn::TensorShape({2,3,2})));
}