aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/TfParser.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2018-12-27 11:23:44 +0000
committerSaoirse Stewart Arm <saoirse.stewart@arm.com>2019-01-07 10:42:57 +0000
commit2ad6cb486164ff3aabe4e9ecabc47f08da48da35 (patch)
tree57ad464aa77179d9e93d7e0c26830d67464667a6 /src/armnnTfParser/TfParser.cpp
parent747ef82c88f9afe14a8b80b6b3b34118353e97f2 (diff)
downloadarmnn-2ad6cb486164ff3aabe4e9ecabc47f08da48da35.tar.gz
IVGCVSW-2384 Add Split parser function to Tensor flow parser
* Added Unit test * Updated TensorFlowSupport.md file Change-Id: I5f07de5e91ffb681c0ad17c7c73ee0326e7f1e0a
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x[-rw-r--r--]src/armnnTfParser/TfParser.cpp104
1 files changed, 104 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 7a213c0909..2d31842205 100644..100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -350,6 +350,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope
{ "Sigmoid", &TfParser::ParseSigmoid },
{ "Softmax", &TfParser::ParseSoftmax },
{ "Softplus", &TfParser::ParseSoftplus },
+ { "Split", &TfParser::ParseSplit },
{ "Tanh", &TfParser::ParseTanh },
{ "MaxPool", &TfParser::ParseMaxPool },
{ "AvgPool", &TfParser::ParseAvgPool },
@@ -2461,6 +2462,109 @@ ParsedTfOperationPtr TfParser::ParseSoftmax(const tensorflow::NodeDef& nodeDef,
return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
}
+ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef,
+ const tensorflow::GraphDef& graphDef)
+{
+ boost::ignore_unused(graphDef);
+
+ std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef);
+ unsigned int numInputs = static_cast<unsigned int>(nodes.size());
+ std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs);
+
+ // The last input is the axis for split operation.
+ if (!HasParsedConstTensor<int32_t>(inputs[numInputs - 1].m_IndexedValue->GetNode().name()))
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "ArmNN only supports split with constant axis. "
+ "Input %1%. Node %2% %3%")
+ % inputs[numInputs - 1].m_IndexedValue->GetNode().name()
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+ ParsedConstTfOperation<int32_t>* shapeNode =
+ boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue);
+
+ // Get the axis tensor data
+ std::vector<int32_t> axisTensorData;
+ shapeNode->GetConstTensor(axisTensorData);
+
+ // This splitDim indicates the data format: 3 is the NHWC, 1 is the NCHW.
+ const unsigned int splitDim = static_cast<unsigned int>(axisTensorData[0]);
+
+ // Armnn supports split along the channel dimension for data formats NHWC and NCHW.
+ if (splitDim == 0 || splitDim == 2)
+ {
+ throw ParseException(
+ boost::str(
+ boost::format(
+ "Dimension %1% for split is not supported by Armnn. "
+ "Node %2% %3%")
+ % splitDim
+ % nodeDef.name()
+ % CHECK_LOCATION().AsString()));
+ }
+
+ // As Armnn only supports splitter outputs of the same shape, therefore num_splits will be limited to an integer.
+ uint32_t num_split = ReadMandatoryNodeUint32Attribute(nodeDef, "num_or_size_splits");
+
+ IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index);
+ TensorInfo inputTensorInfo = inputSlot.GetTensorInfo();
+
+ if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions)
+ {
+ throw armnn::ParseException(
+ boost::str(
+ boost::format(
+ "The number of dimensions: %1% for input tensors of the "
+ "splitter op should be %2% %3%")
+ % inputTensorInfo.GetNumDimensions()
+ % MaxNumOfTensorDimensions
+ % CHECK_LOCATION().AsString()));
+ }
+ auto inputDimSize = inputTensorInfo.GetNumDimensions();
+
+ std::vector<unsigned int> splitterDimSizes(inputDimSize);
+
+ // Add current input shape to splitterDimSizes
+ for (unsigned int i = 0; i < inputDimSize; ++i)
+ {
+ splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
+ }
+
+ if (splitterDimSizes[splitDim] % num_split != 0)
+ {
+ throw ParseException("Number of splits must evenly divide the dimension");
+ }
+ splitterDimSizes[splitDim] /= num_split;
+
+ SplitterDescriptor splitDesc(num_split);
+ for (unsigned int g = 0; g < num_split; ++g)
+ {
+ // Set the size of the views.
+ for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
+ {
+ splitDesc.SetViewSize(g, dimIdx, splitterDimSizes[dimIdx]);
+ }
+ splitDesc.SetViewOriginCoord(g, splitDim, splitterDimSizes[splitDim] * g);
+ }
+
+ IConnectableLayer *layer = m_Network->AddSplitterLayer(splitDesc, nodeDef.name().c_str());
+
+ inputSlot.Connect(layer->GetInputSlot(0));
+
+ TensorShape outShape = TensorShape(static_cast<unsigned int>(splitterDimSizes.size()),
+ splitterDimSizes.data());
+
+ for (unsigned int i = 0; i < layer->GetNumOutputSlots(); ++i)
+ {
+ layer->GetOutputSlot(i).SetTensorInfo(armnn::TensorInfo(outShape, inputTensorInfo.GetDataType()));
+ }
+
+ return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer);
+}
+
ParsedTfOperationPtr TfParser::ParseSoftplus(const tensorflow::NodeDef& nodeDef,
const tensorflow::GraphDef& graphDef)
{