aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-xsrc/armnnTfParser/TfParser.cpp95
1 files changed, 94 insertions, 1 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index d65af2365b..b5a421145a 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -376,7 +376,9 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "Minimum", &TfParser::ParseMinimum },
{ "Equal", &TfParser::ParseEqual },
{ "Pad", &TfParser::ParsePad },
- { "Sub", &TfParser::ParseSub }
+ { "Sub", &TfParser::ParseSub },
+ { "Pack" , &TfParser::ParseStack },
+ { "Stack", &TfParser::ParseStack }
};
const std::list<std::string> TfParser::m_ControlInputs = {
@@ -1961,6 +1963,97 @@ ParsedTfOperationPtr TfParser::ParseSub(const tensorflow::NodeDef& nodeDef, cons
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseStack(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+ std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+
+ unsigned int numInputs = static_cast<unsigned int>(nodes.size());
+ if (numInputs < 1)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Pack/Stack expects at least one input. Got %1% for Node %2% %3%")
+ % numInputs
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
+ // Use the tensor shape of the first input as the "correct" input shape in the descriptor
+ IOutputSlot* input0Slot = &inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ const TensorInfo& inputTensorInfo = input0Slot->GetTensorInfo();
+ auto numDimensions = inputTensorInfo.GetShape().GetNumDimensions();
+
+ // validate axis
+ int32_t axis = ReadMandatoryNodeInt32Attribute(nodeDef, "axis");
+ const int sNumDimensions = (static_cast<int>(numDimensions) + 1);
+ if (!(axis < sNumDimensions && axis >= -sNumDimensions))
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Axis index is not in range. Got %1% for Node %2% %3%")
+ % axis
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+
+ if (axis < 0)
+ {
+ axis = static_cast<int32_t>(numDimensions) + axis + 1;
+ }
+
+ StackDescriptor stackDescriptor;
+ stackDescriptor.m_Axis = static_cast<uint32_t>(axis);
+ stackDescriptor.m_NumInputs = static_cast<uint32_t>(numInputs);
+ stackDescriptor.m_InputShape = inputTensorInfo.GetShape();
+
+ const unsigned int supportedNumDims = 4;
+ for (unsigned int viewIndex = 0; viewIndex < numInputs; ++viewIndex)
+ {
+ IOutputSlot& inputSlot = inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
+ TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
+
+ // Double check dimensions of the tensors
+ if (inputTensorInfo.GetNumDimensions() >= supportedNumDims)
+ {
+ throw armnn::ParseException(
+ boost::str(
+ boost::format(
+ "The number of dimensions: %1% for input tensors of the "
+ "Pack/Stack op. Number of dimensions should be less than %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % supportedNumDims
+ % CHECK_LOCATION().AsString()));
+ }
+ }
+
+ std::vector<unsigned int> outputDimensions;
+ for (unsigned int i = 0; i < stackDescriptor.m_InputShape.GetNumDimensions(); ++i)
+ {
+ outputDimensions.push_back(stackDescriptor.m_InputShape[i]);
+ }
+ outputDimensions.insert(outputDimensions.begin() + axis, numInputs);
+
+ // add Stack Layer
+ IConnectableLayer* const layer = m_Network->AddStackLayer(stackDescriptor, nodeDef.name().c_str());
+
+ for (unsigned int viewIndex = 0; viewIndex < numInputs; ++viewIndex)
+ {
+ IOutputSlot& inputSlot = inputs[viewIndex].m_IndexedValue->ResolveArmnnOutputSlot(inputs[viewIndex].m_Index);
+ inputSlot.Connect(layer->GetInputSlot(viewIndex));
+ }
+
+ layer->GetOutputSlot(0).SetTensorInfo(
+ armnn::TensorInfo(static_cast<uint32_t>(outputDimensions.size()),
+ outputDimensions.data(),
+ inputTensorInfo.GetDataType()));
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
unsigned int CheckPaddingTensor(const ConstTensor& paddingTensor,
const TensorInfo& inputTensorInfo,
const std::string& nodeName)