aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/SubgraphView.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/SubgraphView.cpp')
-rw-r--r--src/armnn/SubgraphView.cpp219
1 files changed, 219 insertions, 0 deletions
diff --git a/src/armnn/SubgraphView.cpp b/src/armnn/SubgraphView.cpp
new file mode 100644
index 0000000000..23f969d8ee
--- /dev/null
+++ b/src/armnn/SubgraphView.cpp
@@ -0,0 +1,219 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "SubgraphView.hpp"
+#include "Graph.hpp"
+
+#include <boost/numeric/conversion/cast.hpp>
+
+#include <utility>
+
+namespace armnn
+{
+
+namespace
+{
+
+template <class C>
+void AssertIfNullsOrDuplicates(const C& container, const std::string& errorMessage)
+{
+ using T = typename C::value_type;
+ std::unordered_set<T> duplicateSet;
+ std::for_each(container.begin(), container.end(), [&duplicateSet, &errorMessage](const T& i)
+ {
+ // Ignore unused for release builds
+ boost::ignore_unused(errorMessage);
+
+ // Check if the item is valid
+ BOOST_ASSERT_MSG(i, errorMessage.c_str());
+
+ // Check if a duplicate has been found
+ BOOST_ASSERT_MSG(duplicateSet.find(i) == duplicateSet.end(), errorMessage.c_str());
+
+ duplicateSet.insert(i);
+ });
+}
+
+} // anonymous namespace
+
+SubgraphView::SubgraphView(Graph& graph)
+ : m_InputSlots{}
+ , m_OutputSlots{}
+ , m_Layers(graph.begin(), graph.end())
+ , m_ParentGraph(&graph)
+{
+ CheckSubgraph();
+}
+
+SubgraphView::SubgraphView(Graph* parentGraph, InputSlots&& inputs, OutputSlots&& outputs, Layers&& layers)
+ : m_InputSlots{inputs}
+ , m_OutputSlots{outputs}
+ , m_Layers{layers}
+ , m_ParentGraph(parentGraph)
+{
+ CheckSubgraph();
+}
+
+SubgraphView::SubgraphView(const SubgraphView& referenceSubgraph,
+ InputSlots&& inputs,
+ OutputSlots&& outputs,
+ Layers&& layers)
+ : m_InputSlots{inputs}
+ , m_OutputSlots{outputs}
+ , m_Layers{layers}
+ , m_ParentGraph(referenceSubgraph.m_ParentGraph)
+{
+ CheckSubgraph();
+}
+
+SubgraphView::SubgraphView(const SubgraphView& subgraph)
+ : m_InputSlots(subgraph.m_InputSlots.begin(), subgraph.m_InputSlots.end())
+ , m_OutputSlots(subgraph.m_OutputSlots.begin(), subgraph.m_OutputSlots.end())
+ , m_Layers(subgraph.m_Layers.begin(), subgraph.m_Layers.end())
+ , m_ParentGraph(subgraph.m_ParentGraph)
+{
+ CheckSubgraph();
+}
+
+SubgraphView::SubgraphView(SubgraphView&& subgraph)
+ : m_InputSlots(std::move(subgraph.m_InputSlots))
+ , m_OutputSlots(std::move(subgraph.m_OutputSlots))
+ , m_Layers(std::move(subgraph.m_Layers))
+ , m_ParentGraph(std::exchange(subgraph.m_ParentGraph, nullptr))
+{
+ CheckSubgraph();
+}
+
+SubgraphView::SubgraphView(const SubgraphView& referenceSubgraph, IConnectableLayer* layer)
+ : m_InputSlots{}
+ , m_OutputSlots{}
+ , m_Layers{boost::polymorphic_downcast<Layer*>(layer)}
+ , m_ParentGraph(referenceSubgraph.m_ParentGraph)
+{
+ unsigned int numInputSlots = layer->GetNumInputSlots();
+ m_InputSlots.resize(numInputSlots);
+ for (unsigned int i = 0; i < numInputSlots; i++)
+ {
+ m_InputSlots.at(i) = boost::polymorphic_downcast<InputSlot*>(&(layer->GetInputSlot(i)));
+ }
+
+ unsigned int numOutputSlots = layer->GetNumOutputSlots();
+ m_OutputSlots.resize(numOutputSlots);
+ for (unsigned int i = 0; i < numOutputSlots; i++)
+ {
+ m_OutputSlots.at(i) = boost::polymorphic_downcast<OutputSlot*>(&(layer->GetOutputSlot(i)));
+ }
+
+ CheckSubgraph();
+}
+
+void SubgraphView::CheckSubgraph()
+{
+ // Check that the sub-graph has a valid parent graph
+ BOOST_ASSERT_MSG(m_ParentGraph, "Sub-graphs must have a parent graph");
+
+ // Check for invalid or duplicate input slots
+ AssertIfNullsOrDuplicates(m_InputSlots, "Sub-graphs cannot contain null or duplicate input slots");
+
+ // Check for invalid or duplicate output slots
+ AssertIfNullsOrDuplicates(m_OutputSlots, "Sub-graphs cannot contain null or duplicate output slots");
+
+ // Check for invalid or duplicate layers
+ AssertIfNullsOrDuplicates(m_Layers, "Sub-graphs cannot contain null or duplicate layers");
+
+ // Check that all the layers of the sub-graph belong to the parent graph
+ std::for_each(m_Layers.begin(), m_Layers.end(), [&](const Layer* l)
+ {
+ BOOST_ASSERT_MSG(std::find(m_ParentGraph->begin(), m_ParentGraph->end(), l) != m_ParentGraph->end(),
+ "Sub-graph layer is not a member of the parent graph");
+ });
+}
+
+void SubgraphView::Update(Graph &graph)
+{
+ m_InputSlots.clear();
+ m_OutputSlots.clear();
+ m_Layers.assign(graph.begin(), graph.end());
+ m_ParentGraph = &graph;
+
+ CheckSubgraph();
+}
+
+const SubgraphView::InputSlots& SubgraphView::GetInputSlots() const
+{
+ return m_InputSlots;
+}
+
+const SubgraphView::OutputSlots& SubgraphView::GetOutputSlots() const
+{
+ return m_OutputSlots;
+}
+
+const InputSlot* SubgraphView::GetInputSlot(unsigned int index) const
+{
+ return m_InputSlots.at(index);
+}
+
+InputSlot* SubgraphView::GetInputSlot(unsigned int index)
+{
+ return m_InputSlots.at(index);
+}
+
+const OutputSlot* SubgraphView::GetOutputSlot(unsigned int index) const
+{
+ return m_OutputSlots.at(index);
+}
+
+OutputSlot* SubgraphView::GetOutputSlot(unsigned int index)
+{
+ return m_OutputSlots.at(index);
+}
+
+unsigned int SubgraphView::GetNumInputSlots() const
+{
+ return boost::numeric_cast<unsigned int>(m_InputSlots.size());
+}
+
+unsigned int SubgraphView::GetNumOutputSlots() const
+{
+ return boost::numeric_cast<unsigned int>(m_OutputSlots.size());
+}
+
+const SubgraphView::Layers & SubgraphView::GetLayers() const
+{
+ return m_Layers;
+}
+
+SubgraphView::Layers::iterator SubgraphView::begin()
+{
+ return m_Layers.begin();
+}
+
+SubgraphView::Iterator SubgraphView::end()
+{
+ return m_Layers.end();
+}
+
+SubgraphView::ConstIterator SubgraphView::begin() const
+{
+ return m_Layers.begin();
+}
+
+SubgraphView::ConstIterator SubgraphView::end() const
+{
+ return m_Layers.end();
+}
+
+SubgraphView::ConstIterator SubgraphView::cbegin() const
+{
+ return begin();
+}
+
+SubgraphView::ConstIterator SubgraphView::cend() const
+{
+ return end();
+}
+
+} // namespace armnn