diff options
Diffstat (limited to 'src/armnnTfParser/test/Transpose.cpp')
-rw-r--r-- | src/armnnTfParser/test/Transpose.cpp | 151 |
1 files changed, 0 insertions, 151 deletions
diff --git a/src/armnnTfParser/test/Transpose.cpp b/src/armnnTfParser/test/Transpose.cpp deleted file mode 100644 index dd73bd90a2..0000000000 --- a/src/armnnTfParser/test/Transpose.cpp +++ /dev/null @@ -1,151 +0,0 @@ -// -// Copyright © 2020 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "armnnTfParser/ITfParser.hpp" -#include "ParserPrototxtFixture.hpp" - -#include <boost/test/unit_test.hpp> -#include <PrototxtConversions.hpp> - -BOOST_AUTO_TEST_SUITE(TensorflowParser) - -namespace -{ - std::string ConvertInt32VectorToOctalString(const std::vector<unsigned int>& data) - { - std::stringstream ss; - ss << "\""; - std::for_each(data.begin(), data.end(), [&ss](unsigned int d) { - ss << armnnUtils::ConvertInt32ToOctalString(static_cast<int>(d)); - }); - ss << "\""; - return ss.str(); - } -} // namespace - -struct TransposeFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser> -{ - TransposeFixture(const armnn::TensorShape& inputShape, - const std::vector<unsigned int>& permuteVectorData) - { - using armnnUtils::ConvertTensorShapeToString; - armnn::TensorShape permuteVectorShape({static_cast<unsigned int>(permuteVectorData.size())}); - - m_Prototext = "node {\n" -" name: \"input\"\n" -" op: \"Placeholder\"\n" -" attr {\n" -" key: \"dtype\"\n" -" value {\n" -" type: DT_FLOAT\n" -" }\n" -" }\n" -" attr {\n" -" key: \"shape\"\n" -" value {\n" -" shape {\n"; - m_Prototext.append(ConvertTensorShapeToString(inputShape)); - m_Prototext.append( -" }\n" -" }\n" -" }\n" -"}\n" -"node {\n" -" name: \"transpose/perm\"\n" -" op: \"Const\"\n" -" attr {\n" -" key: \"dtype\"\n" -" value {\n" -" type: DT_INT32\n" -" }\n" -" }\n" -" attr {\n" -" key: \"value\"\n" -" value {\n" -" tensor {\n" -" dtype: DT_INT32\n" -" tensor_shape {\n" - ); - m_Prototext.append(ConvertTensorShapeToString(permuteVectorShape)); - m_Prototext.append( -" }\n" -" tensor_content: " - ); - m_Prototext.append(ConvertInt32VectorToOctalString(permuteVectorData) + "\n"); - m_Prototext.append( -" }\n" -" }\n" -" }\n" -"}\n" - ); - m_Prototext.append( -"node {\n" -" name: \"output\"\n" -" op: \"Transpose\"\n" -" input: \"input\"\n" -" input: \"transpose/perm\"\n" -" attr {\n" -" key: \"T\"\n" -" value {\n" -" type: DT_FLOAT\n" -" }\n" -" }\n" -" attr {\n" -" key: \"Tperm\"\n" -" value {\n" -" type: DT_INT32\n" -" }\n" -" }\n" -"}\n" - ); - Setup({{"input", inputShape}}, {"output"}); - } -}; - -struct TransposeFixtureWithPermuteData : TransposeFixture -{ - TransposeFixtureWithPermuteData() - : TransposeFixture({2, 2, 3, 4}, - std::vector<unsigned int>({1, 3, 2, 0})) {} -}; - -BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData) -{ - RunTest<4>( - {{"input", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}}}, - {{"output", {0, 24, 4, 28, 8, 32, 1, 25, 5, 29, 9, 33, 2, 26, 6, - 30, 10, 34, 3, 27, 7, 31, 11, 35, 12, 36, 16, 40, 20, 44, 13, 37, - 17, 41, 21, 45, 14, 38, 18, 42, 22, 46, 15, 39, 19, 43, 23, 47}}}); - - BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("output").second.GetShape() - == armnn::TensorShape({2, 4, 3, 2}))); -} - -struct TransposeFixtureWithoutPermuteData : TransposeFixture -{ - // In case permute data is not given, it assumes (n-1,...,0) is given - // where n is the rank of input tensor. - TransposeFixtureWithoutPermuteData() - : TransposeFixture({2, 2, 3, 4}, - std::vector<unsigned int>({3, 2, 1, 0})) {} -}; - -BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteData, TransposeFixtureWithoutPermuteData) -{ - RunTest<4>( - {{"input", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}}}, - {{"output", {0, 24, 12, 36, 4, 28, 16, 40, 8, 32, 20, 44, 1, 25, - 13, 37, 5, 29, 17, 41, 9, 33, 21, 45, 2, 26, 14, 38, 6, 30, 18, - 42,10, 34, 22, 46, 3, 27, 15, 39, 7, 31, 19, 43, 11, 35, 23, 47}}}); - - BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("output").second.GetShape() - == armnn::TensorShape({4, 3, 2, 2}))); -} - -BOOST_AUTO_TEST_SUITE_END() |