diff options
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 8b2a818e6d..83f6950074 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -588,6 +588,7 @@ INetworkPtr TfLiteParser::CreateNetworkFromModel() SetupInputLayers(subgraphIndex); SetupOutputLayers(subgraphIndex); + SetupConstantLayers(subgraphIndex); ++subgraphIndex; } @@ -1742,6 +1743,39 @@ void TfLiteParser::SetupOutputLayers(size_t subgraphIndex) } } +void TfLiteParser::SetupConstantLayers(size_t subgraphIndex) +{ + CHECK_SUBGRAPH(m_Model, subgraphIndex); + + const auto & subGraphPtr = m_Model->subgraphs[subgraphIndex]; + for (unsigned int subgraphIndex = 0; subgraphIndex < m_SubgraphConnections.size(); ++subgraphIndex) + { + for (unsigned int tensorIndex = 0; tensorIndex < m_SubgraphConnections[subgraphIndex].size(); ++tensorIndex) + { + if (m_SubgraphConnections[subgraphIndex][tensorIndex].outputSlot == nullptr && + m_SubgraphConnections[subgraphIndex][tensorIndex].inputSlots.size() > 0) + { + TensorRawPtr tensorPtr = subGraphPtr->tensors[tensorIndex].get(); + armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr); + auto tensorAndData = CreateConstTensor(tensorPtr, + tensorInfo, + armnn::Optional<armnn::PermutationVector&>()); + + std::string layerName = boost::str(boost::format("Constant:%1%") % tensorPtr->name); + IConnectableLayer *layer = + m_Network->AddConstantLayer(tensorAndData.first, layerName.c_str()); + + layer->GetOutputSlot(0).SetTensorInfo(tensorInfo); + RegisterOutputSlots(subgraphIndex, + VIRTUAL_OPERATOR_ID, + layer, + { tensorIndex }); + + } + } + } +} + // example usage: BufferRawPtr bufferPtr = GetBuffer(m_Model, inputs[0]->buffer); TfLiteParser::BufferRawPtr TfLiteParser::GetBuffer(const ModelPtr& model, size_t bufferIndex) { |