aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp104
1 files changed, 51 insertions, 53 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 0410460059..1e304cbfd7 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1158,6 +1158,23 @@ bool TfParser::HasParsedConstTensor(ParsedTfOperation* parsedTfOpPtr) const
return dynamic_cast<ParsedConstTfOperation<Type>*>(parsedTfOpPtr) != nullptr;
}
+unsigned int TfParser::GetConstInputIndex(const std::vector<OutputOfParsedTfOperation>& inputs)
+{
+ for (unsigned int i = 0; i < inputs.size(); i++)
+ {
+ if (HasParsedConstTensor<int32_t>(inputs[i].m_IndexedValue->GetNode().name()))
+ {
+ return i;
+ }
+ }
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "ArmNN only supports operators with constant axis. %1%")
+ % CHECK_LOCATION().AsString()));
+
+}
+
ParsedTfOperationPtr TfParser::ParseConv2D(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{
@@ -2040,22 +2057,12 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
- // The last input is the axis for concatenation.
- if (!HasParsedConstTensor<int32_t>(inputs[numInputs - 1].m_IndexedValue->GetNode().name()))
- {
- throw ParseException(
- boost::str(
- boost::format(
- "ArmNN only supports Concat with constant axis. "
- "Input %1%. Node %2% %3%")
- % inputs[numInputs - 1].m_IndexedValue->GetNode().name()
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
- }
+ // Constant tensor index
+ unsigned int index = GetConstInputIndex(inputs);
+ // Get the axis tensor data
ParsedConstTfOperation<int32_t>* shapeNode =
- boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue);
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[index].m_IndexedValue);
- // Get the axis tensor data
std::vector<int32_t> axisTensorData;
shapeNode->GetConstTensor(axisTensorData);
@@ -2066,13 +2073,13 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
if (concatDim == 0 || concatDim == 2)
{
throw ParseException(
- boost::str(
- boost::format(
+ boost::str(
+ boost::format(
"Dimension %1% for concatenation is not supported by Armnn. "
"Node %2% %3%")
- % concatDim
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
+ % concatDim
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
}
unsigned int numConcatViews = numInputs - 1;
@@ -2090,13 +2097,13 @@ ParsedTfOperationPtr TfParser::ParseConcat(const tensorflow::NodeDef& nodeDef,
if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
{
throw armnn::ParseException(
- boost::str(
- boost::format(
+ boost::str(
+ boost::format(
"The number of dimensions: %1% for input tensors of the "
"concatenation op should be %2% %3%")
- % inputTensorInfo.GetNumDimensions()
- % MaxNumOfTensorDimensions
- % CHECK_LOCATION().AsString()));
+ % inputTensorInfo.GetNumDimensions()
+ % MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
}
// Copy the input tensor shape to mergeDimSizes and initialize the view origin coordinates for the current input
@@ -2605,22 +2612,12 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef,
unsigned int numInputs = static_cast<unsigned int>(nodes.size());
std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
- // The last input is the axis for split operation.
- if (!HasParsedConstTensor<int32_t>(inputs[numInputs - 1].m_IndexedValue->GetNode().name()))
- {
- throw ParseException(
- boost::str(
- boost::format(
- "ArmNN only supports split with constant axis. "
- "Input %1%. Node %2% %3%")
- % inputs[numInputs - 1].m_IndexedValue->GetNode().name()
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
- }
+ // Constant tensor index
+ unsigned int index = GetConstInputIndex(inputs);
+ // Get the axis tensor data
ParsedConstTfOperation<int32_t>* shapeNode =
- boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue);
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[index].m_IndexedValue);
- // Get the axis tensor data
std::vector<int32_t> axisTensorData;
shapeNode->GetConstTensor(axisTensorData);
@@ -2630,34 +2627,35 @@ ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef,
// Armnn supports split along the channel dimension for data formats NHWC and NCHW.
if (splitDim == 0 || splitDim == 2)
{
- throw ParseException(
- boost::str(
- boost::format(
+ throw armnn::ParseException(
+ boost::str(
+ boost::format(
"Dimension %1% for split is not supported by Armnn. "
"Node %2% %3%")
- % splitDim
- % nodeDef.name()
- % CHECK_LOCATION().AsString()));
+ % splitDim
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
}
// As Armnn only supports splitter outputs of the same shape, therefore num_splits will be limited to an integer.
uint32_t num_split = ReadMandatoryNodeUint32Attribute(nodeDef, "num_or_size_splits");
- IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ IOutputSlot& inputSlot = inputs[1 - index].m_IndexedValue->ResolveArmnnOutputSlot(inputs[1 - index].m_Index);
TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
- if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
+ auto inputDimSize = inputTensorInfo.GetNumDimensions();
+
+ if (inputDimSize != MaxNumOfTensorDimensions)
{
throw armnn::ParseException(
- boost::str(
- boost::format(
+ boost::str(
+ boost::format(
"The number of dimensions: %1% for input tensors of the "
- "splitter op should be %2% %3%")
- % inputTensorInfo.GetNumDimensions()
- % MaxNumOfTensorDimensions
- % CHECK_LOCATION().AsString()));
+ "split op should be %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
}
- auto inputDimSize = inputTensorInfo.GetNumDimensions();
std::vector<unsigned int> splitterDimSizes(inputDimSize);