aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/test/Transpose.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfLiteParser/test/Transpose.cpp')
-rw-r--r--src/armnnTfLiteParser/test/Transpose.cpp55
1 files changed, 26 insertions, 29 deletions
diff --git a/src/armnnTfLiteParser/test/Transpose.cpp b/src/armnnTfLiteParser/test/Transpose.cpp
index 2e3190b62e..b2f953e75d 100644
--- a/src/armnnTfLiteParser/test/Transpose.cpp
+++ b/src/armnnTfLiteParser/test/Transpose.cpp
@@ -55,24 +55,20 @@ struct TransposeFixture : public ParserFlatbuffersFixture
},
"is_variable": false
})";
- if (!permuteData.empty())
- {
- m_JsonString += R"(,
- {
- "shape": [
- 3
- ],
- "type": "INT32",
- "buffer": 2,
- "name": "permuteTensor",
- "quantization": {
- "details_type": 0,
- "quantized_dimension": 0
- },
- "is_variable": false
- })";
- }
-
+ m_JsonString += R"(,
+ {
+ "shape": [
+ 3
+ ],
+ "type": "INT32",
+ "buffer": 2,
+ "name": "permuteTensor",
+ "quantization": {
+ "details_type": 0,
+ "quantized_dimension": 0
+ },
+ "is_variable": false
+ })";
m_JsonString += R"(],
"inputs": [
0
@@ -85,10 +81,7 @@ struct TransposeFixture : public ParserFlatbuffersFixture
"opcode_index": 0,
"inputs": [
0)";
- if (!permuteData.empty())
- {
- m_JsonString += R"(,2)";
- }
+ m_JsonString += R"(,2)";
m_JsonString += R"(],
"outputs": [
1
@@ -117,6 +110,7 @@ struct TransposeFixture : public ParserFlatbuffersFixture
}
};
+// Note that this assumes the Tensorflow permutation vector implementation as opposed to the armnn implemenation.
struct TransposeFixtureWithPermuteData : TransposeFixture
{
TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
@@ -128,29 +122,32 @@ BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteDat
{
RunTest<3, armnn::DataType::Float32>(
0,
- {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
- {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
+ {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+ {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
== armnn::TensorShape({2,3,2})));
}
+// Tensorflow default permutation behavior assumes no permute argument will create permute vector [n-1...0],
+// where n is the number of dimensions of the input tensor
+// In this case we should get output shape 3,2,2 given default permutation vector 2,1,0
struct TransposeFixtureWithoutPermuteData : TransposeFixture
{
TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
- "",
- "[ 2, 3, 2 ]") {}
+ "[ 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ]",
+ "[ 3, 2, 2 ]") {}
};
BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteDims, TransposeFixtureWithoutPermuteData)
{
RunTest<3, armnn::DataType::Float32>(
0,
- {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
- {{"outputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}});
+ {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
+ {{"outputTensor", { 1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12 }}});
BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
- == armnn::TensorShape({2,3,2})));
+ == armnn::TensorShape({3,2,2})));
}
BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file