diff options
Diffstat (limited to 'src/armnnTfParser/TfParser.cpp')
-rwxr-xr-x[-rw-r--r--] | src/armnnTfParser/TfParser.cpp | 104 |
1 files changed, 104 insertions, 0 deletions
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 7a213c0909..2d31842205 100644..100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -350,6 +350,7 @@ const std::map<std::string, TfParser::OperationParsingFunction> TfParser::ms_Ope { "Sigmoid", &TfParser::ParseSigmoid }, { "Softmax", &TfParser::ParseSoftmax }, { "Softplus", &TfParser::ParseSoftplus }, + { "Split", &TfParser::ParseSplit }, { "Tanh", &TfParser::ParseTanh }, { "MaxPool", &TfParser::ParseMaxPool }, { "AvgPool", &TfParser::ParseAvgPool }, @@ -2461,6 +2462,109 @@ ParsedTfOperationPtr TfParser::ParseSoftmax(const tensorflow::NodeDef& nodeDef, return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); } +ParsedTfOperationPtr TfParser::ParseSplit(const tensorflow::NodeDef& nodeDef, + const tensorflow::GraphDef& graphDef) +{ + boost::ignore_unused(graphDef); + + std::vector<OutputOfConstNodeDef> nodes = GetTfInputNodes(nodeDef); + unsigned int numInputs = static_cast<unsigned int>(nodes.size()); + std::vector<OutputOfParsedTfOperation> inputs = GetInputParsedTfOperationsChecked(nodeDef, numInputs); + + // The last input is the axis for split operation. + if (!HasParsedConstTensor<int32_t>(inputs[numInputs - 1].m_IndexedValue->GetNode().name())) + { + throw ParseException( + boost::str( + boost::format( + "ArmNN only supports split with constant axis. " + "Input %1%. Node %2% %3%") + % inputs[numInputs - 1].m_IndexedValue->GetNode().name() + % nodeDef.name() + % CHECK_LOCATION().AsString())); + } + ParsedConstTfOperation<int32_t>* shapeNode = + boost::polymorphic_downcast<ParsedConstTfOperation<int32_t>*>(inputs[numInputs - 1].m_IndexedValue); + + // Get the axis tensor data + std::vector<int32_t> axisTensorData; + shapeNode->GetConstTensor(axisTensorData); + + // This splitDim indicates the data format: 3 is the NHWC, 1 is the NCHW. + const unsigned int splitDim = static_cast<unsigned int>(axisTensorData[0]); + + // Armnn supports split along the channel dimension for data formats NHWC and NCHW. + if (splitDim == 0 || splitDim == 2) + { + throw ParseException( + boost::str( + boost::format( + "Dimension %1% for split is not supported by Armnn. " + "Node %2% %3%") + % splitDim + % nodeDef.name() + % CHECK_LOCATION().AsString())); + } + + // As Armnn only supports splitter outputs of the same shape, therefore num_splits will be limited to an integer. + uint32_t num_split = ReadMandatoryNodeUint32Attribute(nodeDef, "num_or_size_splits"); + + IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); + + if (inputTensorInfo.GetNumDimensions() != MaxNumOfTensorDimensions) + { + throw armnn::ParseException( + boost::str( + boost::format( + "The number of dimensions: %1% for input tensors of the " + "splitter op should be %2% %3%") + % inputTensorInfo.GetNumDimensions() + % MaxNumOfTensorDimensions + % CHECK_LOCATION().AsString())); + } + auto inputDimSize = inputTensorInfo.GetNumDimensions(); + + std::vector<unsigned int> splitterDimSizes(inputDimSize); + + // Add current input shape to splitterDimSizes + for (unsigned int i = 0; i < inputDimSize; ++i) + { + splitterDimSizes[i] = inputTensorInfo.GetShape()[i]; + } + + if (splitterDimSizes[splitDim] % num_split != 0) + { + throw ParseException("Number of splits must evenly divide the dimension"); + } + splitterDimSizes[splitDim] /= num_split; + + SplitterDescriptor splitDesc(num_split); + for (unsigned int g = 0; g < num_split; ++g) + { + // Set the size of the views. + for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx) + { + splitDesc.SetViewSize(g, dimIdx, splitterDimSizes[dimIdx]); + } + splitDesc.SetViewOriginCoord(g, splitDim, splitterDimSizes[splitDim] * g); + } + + IConnectableLayer *layer = m_Network->AddSplitterLayer(splitDesc, nodeDef.name().c_str()); + + inputSlot.Connect(layer->GetInputSlot(0)); + + TensorShape outShape = TensorShape(static_cast<unsigned int>(splitterDimSizes.size()), + splitterDimSizes.data()); + + for (unsigned int i = 0; i < layer->GetNumOutputSlots(); ++i) + { + layer->GetOutputSlot(i).SetTensorInfo(armnn::TensorInfo(outShape, inputTensorInfo.GetDataType())); + } + + return std::make_unique<SingleLayerParsedTfOperation>(this, nodeDef, layer); +} + ParsedTfOperationPtr TfParser::ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) { |