diff options
Diffstat (limited to 'src/armnnTfLiteParser/test/Transpose.cpp')
-rw-r--r-- | src/armnnTfLiteParser/test/Transpose.cpp | 55 |
1 files changed, 26 insertions, 29 deletions
diff --git a/src/armnnTfLiteParser/test/Transpose.cpp b/src/armnnTfLiteParser/test/Transpose.cpp index 2e3190b62e..b2f953e75d 100644 --- a/src/armnnTfLiteParser/test/Transpose.cpp +++ b/src/armnnTfLiteParser/test/Transpose.cpp @@ -55,24 +55,20 @@ struct TransposeFixture : public ParserFlatbuffersFixture }, "is_variable": false })"; - if (!permuteData.empty()) - { - m_JsonString += R"(, - { - "shape": [ - 3 - ], - "type": "INT32", - "buffer": 2, - "name": "permuteTensor", - "quantization": { - "details_type": 0, - "quantized_dimension": 0 - }, - "is_variable": false - })"; - } - + m_JsonString += R"(, + { + "shape": [ + 3 + ], + "type": "INT32", + "buffer": 2, + "name": "permuteTensor", + "quantization": { + "details_type": 0, + "quantized_dimension": 0 + }, + "is_variable": false + })"; m_JsonString += R"(], "inputs": [ 0 @@ -85,10 +81,7 @@ struct TransposeFixture : public ParserFlatbuffersFixture "opcode_index": 0, "inputs": [ 0)"; - if (!permuteData.empty()) - { - m_JsonString += R"(,2)"; - } + m_JsonString += R"(,2)"; m_JsonString += R"(], "outputs": [ 1 @@ -117,6 +110,7 @@ struct TransposeFixture : public ParserFlatbuffersFixture } }; +// Note that this assumes the Tensorflow permutation vector implementation as opposed to the armnn implemenation. struct TransposeFixtureWithPermuteData : TransposeFixture { TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]", @@ -128,29 +122,32 @@ BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteDat { RunTest<3, armnn::DataType::Float32>( 0, - {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}}, - {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}}); + {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}}, + {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}}); BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape() == armnn::TensorShape({2,3,2}))); } +// Tensorflow default permutation behavior assumes no permute argument will create permute vector [n-1...0], +// where n is the number of dimensions of the input tensor +// In this case we should get output shape 3,2,2 given default permutation vector 2,1,0 struct TransposeFixtureWithoutPermuteData : TransposeFixture { TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]", - "", - "[ 2, 3, 2 ]") {} + "[ 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ]", + "[ 3, 2, 2 ]") {} }; BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteDims, TransposeFixtureWithoutPermuteData) { RunTest<3, armnn::DataType::Float32>( 0, - {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}}, - {{"outputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}}); + {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}}, + {{"outputTensor", { 1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12 }}}); BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape() - == armnn::TensorShape({2,3,2}))); + == armnn::TensorShape({3,2,2}))); } BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file |