aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSaoirse Stewart <saoirse.stewart@arm.com>2019-02-27 11:07:57 +0000
committerSaoirse Stewart Arm <saoirse.stewart@arm.com>2019-02-27 13:16:23 +0000
commit91c0eff2ff4ff384e013cb69cac1e07e28b9e2b1 (patch)
tree9db521e088c92dd7b9d9e2e6dfd917559aab4745
parentdbfb8549d4aa80115a7049b3e94788fb7a474d9b (diff)
downloadarmnn-91c0eff2ff4ff384e013cb69cac1e07e28b9e2b1.tar.gz
IVGCVSW-2598 Fix for constant axis issue for Tensorflow Parser
Change-Id: I8b081012529aed8e434273259c5a5ef7dc3afff7 Signed-off-by: Finn Williams <finn.williams@arm.com> Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp104
-rw-r--r--src/armnnTfParser/TfParser.hpp2
-rw-r--r--src/armnnTfParser/test/Split.cpp226
3 files changed, 197 insertions, 135 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 0410460059..1e304cbfd7 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1158,6 +1158,23 @@ bool TfParser::HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const
return dynamic_cast<ParsedConstTfOperation<Type>*>(parsedTfOpPtr) != nullptr;
}
+unsigned int TfParser::GetConstInputIndex(const std::vector<OutputOfParsedTfOperation>& inputs)
+{
+ for (unsigned int i = 0; i < inputs.size(); i++)
+ {
+ if (HasParsedConstTensor<int32_t>(inputs[i].m_IndexedValue->GetNode().name()))
+ {
+ return i;
+ }
+ }
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "ArmNN only supports operators with constant axis. %1%")
+ % CHECK_LOCATION().AsString()));
+
+}
+
ParsedTfOperationPtr TfParser::ParseConv2D(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{
@@ -2040,22 +2057,12 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
- // The last input is the axis for concatenation.
- if (!HasParsedConstTensor<int32_t>(inputs[numInputs - 1].m_IndexedValue->GetNode().name()))
- {
- throw ParseException(
- boost::str(
- boost::format(
- "ArmNN only supports Concat with constant axis. "
- "Input %1%. Node %2% %3%")
- % inputs[numInputs - 1].m_IndexedValue->GetNode().name()
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
- }
+ // Constant tensor index
+ unsigned int index = GetConstInputIndex(inputs);
+ // Get the axis tensor data
ParsedConstTfOperation<int32_t>* shapeNode =
- boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue);
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[index].m_IndexedValue);
- // Get the axis tensor data
std::vector<int32_t> axisTensorData;
shapeNode->GetConstTensor(axisTensorData);
@@ -2066,13 +2073,13 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
if (concatDim == 0 || concatDim == 2)
{
throw ParseException(
- boost::str(
- boost::format(
+ boost::str(
+ boost::format(
"Dimension %1% for concatenation is not supported by Armnn. "
"Node %2% %3%")
- % concatDim
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
+ % concatDim
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
}
unsigned int numConcatViews = numInputs - 1;
@@ -2090,13 +2097,13 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
{
throw armnn::ParseException(
- boost::str(
- boost::format(
+ boost::str(
+ boost::format(
"The number of dimensions: %1% for input tensors of the "
"concatenation op should be %2% %3%")
- % inputTensorInfo.GetNumDimensions()
- % MaxNumOfTensorDimensions
- % CHECK_LOCATION().AsString()));
+ % inputTensorInfo.GetNumDimensions()
+ % MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
}
// Copy the input tensor shape to mergeDimSizes and initialize the view origin coordinates for the current input
@@ -2605,22 +2612,12 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef,
unsigned int numInputs = static_cast<unsigned int>(nodes.size());
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
- // The last input is the axis for split operation.
- if (!HasParsedConstTensor<int32_t>(inputs[numInputs - 1].m_IndexedValue->GetNode().name()))
- {
- throw ParseException(
- boost::str(
- boost::format(
- "ArmNN only supports split with constant axis. "
- "Input %1%. Node %2% %3%")
- % inputs[numInputs - 1].m_IndexedValue->GetNode().name()
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
- }
+ // Constant tensor index
+ unsigned int index = GetConstInputIndex(inputs);
+ // Get the axis tensor data
ParsedConstTfOperation<int32_t>* shapeNode =
- boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue);
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[index].m_IndexedValue);
- // Get the axis tensor data
std::vector<int32_t> axisTensorData;
shapeNode->GetConstTensor(axisTensorData);
@@ -2630,34 +2627,35 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef,
// Armnn supports split along the channel dimension for data formats NHWC and NCHW.
if (splitDim == 0 || splitDim == 2)
{
- throw ParseException(
- boost::str(
- boost::format(
+ throw armnn::ParseException(
+ boost::str(
+ boost::format(
"Dimension %1% for split is not supported by Armnn. "
"Node %2% %3%")
- % splitDim
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
+ % splitDim
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
}
// As Armnn only supports splitter outputs of the same shape, therefore num_splits will be limited to an integer.
uint32_t num_split = ReadMandatoryNodeUint32Attribute(nodeDef, "num_or_size_splits");
- IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ IOutputSlot& inputSlot = inputs[1 - index].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1 - index].m_Index);
TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
- if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
+ auto inputDimSize = inputTensorInfo.GetNumDimensions();
+
+ if (inputDimSize != MaxNumOfTensorDimensions)
{
throw armnn::ParseException(
- boost::str(
- boost::format(
+ boost::str(
+ boost::format(
"The number of dimensions: %1% for input tensors of the "
- "splitter op should be %2% %3%")
- % inputTensorInfo.GetNumDimensions()
- % MaxNumOfTensorDimensions
- % CHECK_LOCATION().AsString()));
+ "split op should be %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
}
- auto inputDimSize = inputTensorInfo.GetNumDimensions();
std::vector<unsigned int> splitterDimSizes(inputDimSize);
diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp
index 46da55f1d1..95ccf397c1 100644
--- a/src/armnnTfParser/TfParser.hpp
+++ b/src/armnnTfParser/TfParser.hpp
@@ -129,6 +129,8 @@ private:
template<typename Type>
bool HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const;
+ unsigned int GetConstInputIndex(const std::vector<OutputOfParsedTfOperation>& inputs);
+
ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseAddN(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
diff --git a/src/armnnTfParser/test/Split.cpp b/src/armnnTfParser/test/Split.cpp
index de6b5d861e..87cd6544c9 100644
--- a/src/armnnTfParser/test/Split.cpp
+++ b/src/armnnTfParser/test/Split.cpp
@@ -11,93 +11,140 @@ BOOST_AUTO_TEST_SUITE(TensorflowParser)
struct SplitFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
- SplitFixture() {
- m_Prototext =
- "node { \n"
- " name: \"graphInput\" \n"
- " op: \"Placeholder\" \n"
- " attr { \n"
- " key: \"dtype\" \n"
- " value { \n"
- " type: DT_FLOAT \n"
- " } \n"
- " } \n"
- " attr { \n"
- " key: \"shape\" \n"
- " value { \n"
- " shape { \n"
- " } \n"
- " } \n"
- " } \n"
- " } \n"
- " node {"
- " name: \"splitInput\" \n"
- " op: \"Const\" \n"
- "attr {\n"
- " key: \"dtype\" \n"
- " value {"
- " type: DT_INT32"
- " }"
- "}"
- "attr {"
- " key: \"value\"\n"
- " value { "
- " tensor {"
- " dtype: DT_INT32"
- " tensor_shape {"
- "}"
- "int_val: 1"
- "}"
- "}"
- "}"
- "}"
- "node { \n"
- " name: \"Split\" \n"
- " op: \"Split\" \n"
- "input: \"graphInput\"\n"
- "input: \"splitInput\"\n"
- "attr { \n "
- "key: \"T\"\n"
- "value {\n"
- "type: DT_FLOAT\n"
- " }\n"
- "}\n"
- "\n"
- " attr { \n"
- " key: \"num_or_size_splits\" \n"
- " value { \n"
- " i:2 \n "
- " } \n"
- " } \n"
- "} \n"
- "node { \n"
- "name: \"Relu_1\"\n"
- "op: \"Relu\"\n"
- "input: \"Split:0\"\n"
- "attr { \n "
- "key: \"T\"\n"
- "value {\n"
- "type: DT_FLOAT\n"
- " }\n"
- "}\n"
- "}\n"
- "node { \n"
- "name: \"Relu_2\"\n"
- "op: \"Relu\"\n"
- "input: \"Split:1\"\n"
- "attr { \n "
- "key: \"T\"\n"
- "value {\n"
- "type: DT_FLOAT\n"
- " }\n"
- "}\n"
- "}\n";
+ SplitFixture(bool withDimZero=false) {
+ m_Prototext = R"(
+ node {
+ name: "graphInput"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "graphInput2"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "multiplication"
+ op : "Mul"
+ input: "graphInput"
+ input: "graphInput2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "SplitInput"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: )";
- Setup( { { "graphInput", { 1, 2, 2 , 2} } },
+ if(withDimZero)
+ {
+ m_Prototext += std::to_string(3);
+ }
+ else
+ {
+ m_Prototext += std::to_string(1);
+ }
+
+ m_Prototext += R"(
+ }
+ }
+ }
+ }
+ node {
+ name: "Split"
+ op: "Split" )";
+ if(withDimZero)
+ {
+ m_Prototext += "input: \"SplitInput\"\n";
+ m_Prototext += "input: \"multiplication\"\n";
+ }
+ else
+ {
+ m_Prototext += "input: \"graphInput\"\n";
+ m_Prototext += "input: \"SplitInput\"\n";
+ }
+ m_Prototext += R"(
+ attr {
+ key: "num_or_size_splits"
+ value {
+ i: 2
+ }
+ }
+ }
+ node {
+ name: "Relu_1"
+ op: "Relu"
+ input: "Split:0"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ }
+ node {
+ name: "Relu_2"
+ op: "Relu"
+ input:"Split:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ } )";
+
+ Setup( { { "graphInput", { 1, 2, 2 , 2} } , { "graphInput2", { 1, 2, 2 , 2} }},
{ "Relu_1", "Relu_2" });
}
};
+struct InputFirstSplitFixture : SplitFixture
+{
+ InputFirstSplitFixture() : SplitFixture(true) {}
+};
+
BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
{
BOOST_TEST(
@@ -111,4 +158,19 @@ BOOST_FIXTURE_TEST_CASE(ParseAxisOneSplitTwo, SplitFixture)
{ "Relu_2", { 0.0f, 0.5f, 0.0f, 1.75f } } });
}
+BOOST_FIXTURE_TEST_CASE(ParseSplit, InputFirstSplitFixture)
+{
+
+ BOOST_TEST(
+ (m_Parser->GetNetworkOutputBindingInfo("Relu_1").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
+
+ BOOST_TEST(
+ (m_Parser->GetNetworkOutputBindingInfo("Relu_2").second.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
+
+ RunTest<4>({ { "graphInput", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } ,
+ { "graphInput2", { -1.0f, -0.5f, 1.25f, -3.0f, 0.0f, 0.5f, -0.75f , 1.75f } } },
+ { { "Relu_1", { 1.0f, 1.5625f, 0, 0.5625f } },
+ { "Relu_2", { 0.25, 9.0f, 0.25f, 3.0625f } } });
+}
+
BOOST_AUTO_TEST_SUITE_END()