aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfLiteParser/TfLiteParser.cpp
diff options
context:
space:
mode:
authorBruno Goncalves <bruno.slackware@gmail.com>2018-12-27 14:21:43 -0200
committerMatthew Bentham <matthew.bentham@arm.com>2019-01-29 11:15:38 +0000
commit3d7efe9a21ad47f33d330240b3b901ad7d5a5a81 (patch)
tree3bfa3dc2ba52c8a6e092ebe78c8ed4059077a5fb /src/armnnTfLiteParser/TfLiteParser.cpp
parentd161ba0bc83fa14f7aea4c629ca3e6ea04a2dc34 (diff)
downloadarmnn-3d7efe9a21ad47f33d330240b3b901ad7d5a5a81.tar.gz
Added ConstantLayer support to TfLiteParser
Change-Id: Iecc4fe8208b442d9c872e56c3d47249f959c6cc1
Diffstat (limited to 'src/armnnTfLiteParser/TfLiteParser.cpp')
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp34
1 files changed, 34 insertions, 0 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 8b2a818e6d..83f6950074 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -588,6 +588,7 @@ INetworkPtr TfLiteParser::CreateNetworkFromModel()
SetupInputLayers(subgraphIndex);
SetupOutputLayers(subgraphIndex);
+ SetupConstantLayers(subgraphIndex);
++subgraphIndex;
}
@@ -1742,6 +1743,39 @@ void TfLiteParser::SetupOutputLayers(size_t subgraphIndex)
}
}
+void TfLiteParser::SetupConstantLayers(size_t subgraphIndex)
+{
+ CHECK_SUBGRAPH(m_Model, subgraphIndex);
+
+ const auto & subGraphPtr = m_Model->subgraphs[subgraphIndex];
+ for (unsigned int subgraphIndex = 0; subgraphIndex < m_SubgraphConnections.size(); ++subgraphIndex)
+ {
+ for (unsigned int tensorIndex = 0; tensorIndex < m_SubgraphConnections[subgraphIndex].size(); ++tensorIndex)
+ {
+ if (m_SubgraphConnections[subgraphIndex][tensorIndex].outputSlot == nullptr &&
+ m_SubgraphConnections[subgraphIndex][tensorIndex].inputSlots.size() > 0)
+ {
+ TensorRawPtr tensorPtr = subGraphPtr->tensors[tensorIndex].get();
+ armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr);
+ auto tensorAndData = CreateConstTensor(tensorPtr,
+ tensorInfo,
+ armnn::Optional<armnn::PermutationVector&>());
+
+ std::string layerName = boost::str(boost::format("Constant:%1%") % tensorPtr->name);
+ IConnectableLayer *layer =
+ m_Network->AddConstantLayer(tensorAndData.first, layerName.c_str());
+
+ layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
+ RegisterOutputSlots(subgraphIndex,
+ VIRTUAL_OPERATOR_ID,
+ layer,
+ { tensorIndex });
+
+ }
+ }
+ }
+}
+
// example usage: BufferRawPtr bufferPtr = GetBuffer(m_Model, inputs[0]->buffer);
TfLiteParser::BufferRawPtr TfLiteParser::GetBuffer(const ModelPtr& model, size_t bufferIndex)
{