aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSang-Hoon Park <sang-hoon.park@arm.com>2020-02-18 11:27:35 +0000
committerSang-Hoon Park <sang-hoon.park@arm.com>2020-02-24 13:29:57 +0000
commitdd3f71b64072c44cec65a7a883d0c3a29659645c (patch)
treed48b0508811e88d47b4b95a1777742312b2ba4db
parentcf34f510b891b87b4c7be5edb30272b36cab7b51 (diff)
downloadarmnn-dd3f71b64072c44cec65a7a883d0c3a29659645c.tar.gz
COMPMID-3060: Add TF Parser support for Transpose
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com> Change-Id: I9661787071554b38c5b0ab3c98431f3863b98520
-rw-r--r--CMakeLists.txt1
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp57
-rw-r--r--src/armnnTfParser/TfParser.hpp1
-rw-r--r--src/armnnTfParser/test/Transpose.cpp151
-rw-r--r--src/armnnUtils/PrototxtConversions.cpp14
-rw-r--r--src/armnnUtils/PrototxtConversions.hpp8
-rw-r--r--src/armnnUtils/test/PrototxtConversionsTest.cpp97
7 files changed, 328 insertions, 1 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e5876db926..abdcc37c4f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -713,6 +713,7 @@ if(BUILD_UNIT_TESTS)
src/armnnTfParser/test/Stack.cpp
src/armnnTfParser/test/Sub.cpp
src/armnnTfParser/test/StridedSlice.cpp
+ src/armnnTfParser/test/Transpose.cpp
)
endif()
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index b5a421145a..124c5fdcc7 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -378,7 +378,8 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "Pad", &TfParser::ParsePad },
{ "Sub", &TfParser::ParseSub },
{ "Pack" , &TfParser::ParseStack },
- { "Stack", &TfParser::ParseStack }
+ { "Stack", &TfParser::ParseStack },
+ { "Transpose", &TfParser::ParseTranspose },
};
const std::list<std::string> TfParser::m_ControlInputs = {
@@ -2054,6 +2055,60 @@ ParsedTfOperationPtr TfParser::ParseStack(const tensorflow::NodeDef& nodeDef, co
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseTranspose(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ auto inputs = GetInputParsedTfOperationsChecked(nodeDef, 2);
+ const auto inputCount = inputs.size();
+
+ if (inputCount != 2)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "The number of given input is %1%. It should be two for Transpose op."
+ "Node %2% %3%")
+ % inputCount
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+
+ auto* input0Slot = &inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+
+ const auto constInput = inputs[GetConstInputIndex(inputs)];
+ auto* permuteVectorInput =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(constInput.m_IndexedValue);
+ const auto& permuteVectorInfo = permuteVectorInput->GetTensorInfo();
+
+ std::vector<int32_t> permuteVectorData;
+ permuteVectorInput->GetConstTensor(permuteVectorData);
+
+ std::vector<unsigned int> armnnPermuteVectorData(permuteVectorData.size());
+ std::vector<int32_t>::iterator it;
+
+ for (unsigned int i = 0u; i < permuteVectorData.size(); ++i)
+ {
+ it = std::find(permuteVectorData.begin(), permuteVectorData.end(), i);
+ armnnPermuteVectorData[i] = static_cast<unsigned int>(std::distance(permuteVectorData.begin(), it));
+ }
+
+ const auto permutationVector = PermutationVector(armnnPermuteVectorData.data(), permuteVectorInfo.GetNumElements());
+ const auto desc = PermuteDescriptor(permutationVector);
+
+ auto* layer = m_Network->AddPermuteLayer(desc, nodeDef.name().c_str());
+ BOOST_ASSERT(layer);
+
+ input0Slot->Connect(layer->GetInputSlot(0));
+
+ const auto& input0Info = input0Slot->GetTensorInfo();
+ armnn::TensorInfo outputInfo {input0Info};
+ outputInfo.SetShape(armnnUtils::Permuted(input0Info.GetShape(), desc.m_DimMappings));
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
unsigned int CheckPaddingTensor(const ConstTensor& paddingTensor,
const TensorInfo& inputTensorInfo,
const std::string& nodeName)
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp
index 9277d44cb1..94499ea52d 100644
--- a/src/armnnTfParser/TfParser.hpp
+++ b/src/armnnTfParser/TfParser.hpp
@@ -171,6 +171,7 @@ private:
ParsedTfOperationPtr ParsePad(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseSub(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseStack(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
+ ParsedTfOperationPtr ParseTranspose(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc);
ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false);
ParsedTfOperationPtr AddRealDivLayer(const tensorflow::NodeDef& nodeDef);
diff --git a/src/armnnTfParser/test/Transpose.cpp b/src/armnnTfParser/test/Transpose.cpp
new file mode 100644
index 0000000000..dd73bd90a2
--- /dev/null
+++ b/src/armnnTfParser/test/Transpose.cpp
@@ -0,0 +1,151 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+#include <boost/test/unit_test.hpp>
+#include <PrototxtConversions.hpp>
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+namespace
+{
+ std::string ConvertInt32VectorToOctalString(const std::vector<unsigned int>& data)
+ {
+ std::stringstream ss;
+ ss << "\"";
+ std::for_each(data.begin(), data.end(), [&ss](unsigned int d) {
+ ss << armnnUtils::ConvertInt32ToOctalString(static_cast<int>(d));
+ });
+ ss << "\"";
+ return ss.str();
+ }
+} // namespace
+
+struct TransposeFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ TransposeFixture(const armnn::TensorShape& inputShape,
+ const std::vector<unsigned int>& permuteVectorData)
+ {
+ using armnnUtils::ConvertTensorShapeToString;
+ armnn::TensorShape permuteVectorShape({static_cast<unsigned int>(permuteVectorData.size())});
+
+ m_Prototext = "node {\n"
+" name: \"input\"\n"
+" op: \"Placeholder\"\n"
+" attr {\n"
+" key: \"dtype\"\n"
+" value {\n"
+" type: DT_FLOAT\n"
+" }\n"
+" }\n"
+" attr {\n"
+" key: \"shape\"\n"
+" value {\n"
+" shape {\n";
+ m_Prototext.append(ConvertTensorShapeToString(inputShape));
+ m_Prototext.append(
+" }\n"
+" }\n"
+" }\n"
+"}\n"
+"node {\n"
+" name: \"transpose/perm\"\n"
+" op: \"Const\"\n"
+" attr {\n"
+" key: \"dtype\"\n"
+" value {\n"
+" type: DT_INT32\n"
+" }\n"
+" }\n"
+" attr {\n"
+" key: \"value\"\n"
+" value {\n"
+" tensor {\n"
+" dtype: DT_INT32\n"
+" tensor_shape {\n"
+ );
+ m_Prototext.append(ConvertTensorShapeToString(permuteVectorShape));
+ m_Prototext.append(
+" }\n"
+" tensor_content: "
+ );
+ m_Prototext.append(ConvertInt32VectorToOctalString(permuteVectorData) + "\n");
+ m_Prototext.append(
+" }\n"
+" }\n"
+" }\n"
+"}\n"
+ );
+ m_Prototext.append(
+"node {\n"
+" name: \"output\"\n"
+" op: \"Transpose\"\n"
+" input: \"input\"\n"
+" input: \"transpose/perm\"\n"
+" attr {\n"
+" key: \"T\"\n"
+" value {\n"
+" type: DT_FLOAT\n"
+" }\n"
+" }\n"
+" attr {\n"
+" key: \"Tperm\"\n"
+" value {\n"
+" type: DT_INT32\n"
+" }\n"
+" }\n"
+"}\n"
+ );
+ Setup({{"input", inputShape}}, {"output"});
+ }
+};
+
+struct TransposeFixtureWithPermuteData : TransposeFixture
+{
+ TransposeFixtureWithPermuteData()
+ : TransposeFixture({2, 2, 3, 4},
+ std::vector<unsigned int>({1, 3, 2, 0})) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(TransposeWithPermuteData, TransposeFixtureWithPermuteData)
+{
+ RunTest<4>(
+ {{"input", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}}},
+ {{"output", {0, 24, 4, 28, 8, 32, 1, 25, 5, 29, 9, 33, 2, 26, 6,
+ 30, 10, 34, 3, 27, 7, 31, 11, 35, 12, 36, 16, 40, 20, 44, 13, 37,
+ 17, 41, 21, 45, 14, 38, 18, 42, 22, 46, 15, 39, 19, 43, 23, 47}}});
+
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("output").second.GetShape()
+ == armnn::TensorShape({2, 4, 3, 2})));
+}
+
+struct TransposeFixtureWithoutPermuteData : TransposeFixture
+{
+ // In case permute data is not given, it assumes (n-1,...,0) is given
+ // where n is the rank of input tensor.
+ TransposeFixtureWithoutPermuteData()
+ : TransposeFixture({2, 2, 3, 4},
+ std::vector<unsigned int>({3, 2, 1, 0})) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(TransposeWithoutPermuteData, TransposeFixtureWithoutPermuteData)
+{
+ RunTest<4>(
+ {{"input", {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47}}},
+ {{"output", {0, 24, 12, 36, 4, 28, 16, 40, 8, 32, 20, 44, 1, 25,
+ 13, 37, 5, 29, 17, 41, 9, 33, 21, 45, 2, 26, 14, 38, 6, 30, 18,
+ 42,10, 34, 22, 46, 3, 27, 15, 39, 7, 31, 19, 43, 11, 35, 23, 47}}});
+
+ BOOST_TEST((m_Parser->GetNetworkOutputBindingInfo("output").second.GetShape()
+ == armnn::TensorShape({4, 3, 2, 2})));
+}
+
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnnUtils/PrototxtConversions.cpp b/src/armnnUtils/PrototxtConversions.cpp
index 1a6c053355..cc86321a27 100644
--- a/src/armnnUtils/PrototxtConversions.cpp
+++ b/src/armnnUtils/PrototxtConversions.cpp
@@ -4,6 +4,7 @@
//
#include "PrototxtConversions.hpp"
+#include "armnn/Tensor.hpp"
#include <boost/format.hpp>
@@ -29,4 +30,17 @@ std::string ConvertInt32ToOctalString(int value)
return returnString;
}
+/// Converts an TensorShape into Prototxt representation
+std::string ConvertTensorShapeToString(const armnn::TensorShape& shape)
+{
+ std::stringstream ss;
+ for (unsigned int i = 0 ; i < shape.GetNumDimensions() ; i++)
+ {
+ ss << "dim {\n";
+ ss << "size: " << std::to_string(shape[i]) << "\n";
+ ss << "}\n";
+ }
+ return ss.str();
+
+}
} // namespace armnnUtils
diff --git a/src/armnnUtils/PrototxtConversions.hpp b/src/armnnUtils/PrototxtConversions.hpp
index c90af9edff..fb066729fe 100644
--- a/src/armnnUtils/PrototxtConversions.hpp
+++ b/src/armnnUtils/PrototxtConversions.hpp
@@ -7,10 +7,18 @@
#include <string>
+namespace armnn
+{
+class TensorShape;
+} // namespace armnn
+
namespace armnnUtils
{
/// Converts an int value into the Prototxt octal representation
std::string ConvertInt32ToOctalString(int value);
+/// Converts an TensorShape into Prototxt representation
+std::string ConvertTensorShapeToString(const armnn::TensorShape& shape);
+
} // namespace armnnUtils
diff --git a/src/armnnUtils/test/PrototxtConversionsTest.cpp b/src/armnnUtils/test/PrototxtConversionsTest.cpp
index e06fbe0f2e..f263a52340 100644
--- a/src/armnnUtils/test/PrototxtConversionsTest.cpp
+++ b/src/armnnUtils/test/PrototxtConversionsTest.cpp
@@ -4,6 +4,7 @@
//
#include <PrototxtConversions.hpp>
+#include "armnn/Tensor.hpp"
#include <boost/test/unit_test.hpp>
@@ -38,4 +39,100 @@ BOOST_AUTO_TEST_CASE(ConvertInt32ToOctalStringTest)
BOOST_ASSERT(octalString.compare("\\\\000\\\\000\\\\000\\\\377"));
}
+BOOST_AUTO_TEST_CASE(ConvertTensorShapeToStringTest)
+{
+ using armnnUtils::ConvertTensorShapeToString;
+ using armnn::TensorShape;
+
+ auto createAndConvert = [](std::initializer_list<unsigned int> dims) -> std::string
+ {
+ auto shape = TensorShape{dims};
+ return ConvertTensorShapeToString(shape);
+ };
+
+ auto output_string = createAndConvert({5});
+ BOOST_ASSERT(output_string.compare(
+ "dim {\n"
+ "size: 5\n"
+ "}"));
+
+ output_string = createAndConvert({4, 5});
+ BOOST_ASSERT(output_string.compare(
+ "dim {\n"
+ "size: 4\n"
+ "}\n"
+ "dim {\n"
+ "size: 5\n"
+ "}"
+ ));
+
+ output_string = createAndConvert({3, 4, 5});
+ BOOST_ASSERT(output_string.compare(
+ "dim {\n"
+ "size: 3\n"
+ "}\n"
+ "dim {\n"
+ "size: 4\n"
+ "}\n"
+ "dim {\n"
+ "size: 5\n"
+ "}"
+ ));
+
+ output_string = createAndConvert({2, 3, 4, 5});
+ BOOST_ASSERT(output_string.compare(
+ "dim {\n"
+ "size: 2\n"
+ "}\n"
+ "dim {\n"
+ "size: 3\n"
+ "}\n"
+ "dim {\n"
+ "size: 4\n"
+ "}\n"
+ "dim {\n"
+ "size: 5\n"
+ "}"
+ ));
+
+ output_string = createAndConvert({1, 2, 3, 4, 5});
+ BOOST_ASSERT(output_string.compare(
+ "dim {\n"
+ "size: 1\n"
+ "}\n"
+ "dim {\n"
+ "size: 2\n"
+ "}\n"
+ "dim {\n"
+ "size: 3\n"
+ "}\n"
+ "dim {\n"
+ "size: 4\n"
+ "}\n"
+ "dim {\n"
+ "size: 5\n"
+ "}"
+ ));
+
+ output_string = createAndConvert({0xffffffff, 0xffffffff});
+ BOOST_ASSERT(output_string.compare(
+ "dim {\n"
+ "size: 4294967295\n"
+ "}\n"
+ "dim {\n"
+ "size: 4294967295\n"
+ "}"
+ ));
+
+ output_string = createAndConvert({1, 0});
+ BOOST_ASSERT(output_string.compare(
+ "dim {\n"
+ "size: 1\n"
+ "}\n"
+ "dim {\n"
+ "size: 0\n"
+ "}"
+ ));
+}
+
BOOST_AUTO_TEST_SUITE_END()