about summary refs log tree commit diff
path: root/src/armnn/Layer.cpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
committertelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
commit4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch)
treec9a70aeb2887006160c1b3d265c27efadb7bdbae /src/armnn/Layer.cpp
downloadarmnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'src/armnn/Layer.cpp')
-rw-r--r--src/armnn/Layer.cpp220
1 file changed, 220 insertions(+), 0 deletions(-)
diff --git a/src/armnn/Layer.cpp b/src/armnn/Layer.cpp
new file mode 100644
index 0000000000..20a8ba4926
--- /dev/null
+++ b/src/armnn/Layer.cpp
@@ -0,0 +1,220 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#include "Layer.hpp"
+
+#include "Graph.hpp"
+#include "backends/WorkloadData.hpp"
+
+#include <boost/cast.hpp>
+#include <boost/format.hpp>
+#include <boost/log/trivial.hpp>
+
+#include <numeric>
+
+namespace armnn
+{
+
+void InputSlot::Insert(Layer& layer)
+{
+ BOOST_ASSERT(layer.GetNumInputSlots() <= 1);
+ BOOST_ASSERT(layer.GetNumOutputSlots() == 1);
+
+ OutputSlot* const prevSlot = GetConnectedOutputSlot();
+
+ if (prevSlot != nullptr)
+ {
+ // Disconnect parent from this
+ prevSlot->Disconnect(*this);
+
+ // Connect inserted layer to parent
+ BOOST_ASSERT(layer.GetNumInputSlots() == 1);
+ prevSlot->Connect(layer.GetInputSlot(0));
+
+ // Set tensor info for inserted layer
+ const TensorInfo& tensorInfo = prevSlot->GetTensorInfo();
+ layer.GetOutputHandler().SetTensorInfo(tensorInfo);
+ }
+
+ // Connect inserted layer to this
+ layer.GetOutputSlot(0).Connect(*this);
+}
+
+const InputSlot* OutputSlot::GetConnection(unsigned int index) const
+{
+ ValidateConnectionIndex(index);
+ return m_Connections[index];
+}
+
+InputSlot* OutputSlot::GetConnection(unsigned int index)
+{
+ ValidateConnectionIndex(index);
+ return m_Connections[index];
+}
+
+void OutputSlot::SetTensorInfo(const TensorInfo& tensorInfo)
+{
+ GetOutputHandler().SetTensorInfo(tensorInfo);
+}
+
+const TensorInfo& OutputSlot::GetTensorInfo() const
+{
+ return GetOutputHandler().GetTensorInfo();
+}
+
+bool OutputSlot::IsTensorInfoSet() const
+{
+ return GetOutputHandler().IsTensorInfoSet();
+}
+
+bool OutputSlot::ValidateTensorShape(const TensorShape& shape) const
+{
+ BOOST_ASSERT_MSG(IsTensorInfoSet(), "TensorInfo must be set in order to validate the shape.");
+ return shape == m_OutputHandler.GetTensorInfo().GetShape();
+}
+
+int OutputSlot::Connect(InputSlot& destination)
+{
+ destination.SetConnection(this);
+ m_Connections.push_back(&destination);
+ return boost::numeric_cast<int>(m_Connections.size() - 1);
+}
+
+void OutputSlot::Disconnect(InputSlot& slot)
+{
+ slot.SetConnection(nullptr);
+ m_Connections.erase(std::remove(m_Connections.begin(), m_Connections.end(), &slot), m_Connections.end());
+}
+
+void OutputSlot::DisconnectAll()
+{
+ while (GetNumConnections() > 0)
+ {
+ InputSlot& connection = *GetConnection(0);
+ Disconnect(connection);
+ }
+}
+
+void OutputSlot::MoveAllConnections(OutputSlot& destination)
+{
+ while (GetNumConnections() > 0)
+ {
+ InputSlot& connection = *GetConnection(0);
+ Disconnect(connection);
+ destination.Connect(connection);
+ }
+}
+
+void OutputSlot::ValidateConnectionIndex(unsigned int index) const
+{
+ if (boost::numeric_cast<std::size_t>(index) >= m_Connections.size())
+ {
+ throw InvalidArgumentException(
+ boost::str(boost::format("GetConnection: Invalid index %1% provided") % index));
+ }
+}
+
+Layer::Layer(unsigned int numInputSlots, unsigned int numOutputSlots, LayerType type, const char* name)
+: m_OutputHandlers(numOutputSlots)
+, m_LayerName(name ? name : "")
+, m_Type(type)
+, m_ComputeDevice(Compute::Undefined)
+{
+ m_InputSlots.reserve(numInputSlots);
+ for (unsigned int i = 0; i < numInputSlots; ++i)
+ {
+ m_InputSlots.emplace_back(*this, i);
+ }
+
+ m_OutputSlots.reserve(numOutputSlots);
+ for (unsigned int i = 0; i < numOutputSlots; ++i)
+ {
+ m_OutputSlots.emplace_back(*this, m_OutputHandlers[i]);
+ }
+}
+
+void Layer::CollectWorkloadInputs(WorkloadDataCollector& dataCollector, const Graph& graph) const
+{
+ for (auto&& inputSlot : GetInputSlots())
+ {
+ // The graph must be well-formed at this point
+ BOOST_ASSERT(inputSlot.GetConnection());
+ const OutputHandler& outputHandler = inputSlot.GetConnectedOutputSlot()->GetOutputHandler();
+ dataCollector.Push(outputHandler.GetData(), outputHandler.GetTensorInfo());
+ }
+}
+
+void Layer::CollectWorkloadOutputs(WorkloadDataCollector& dataCollector, const Graph& graph) const
+{
+ for (auto&& outputHandler : m_OutputHandlers)
+ {
+ outputHandler.CollectWorkloadOutputs(dataCollector);
+ }
+}
+
+void Layer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory)
+{
+ for (auto&& outputHandler : m_OutputHandlers)
+ {
+ outputHandler.CreateTensorHandles(factory);
+ }
+}
+
+DataType Layer::GetDataType() const
+{
+ if (GetNumInputSlots() > 0) // Ignore the input layer
+ {
+ return GetInputSlot(0).GetConnection()->GetTensorInfo().GetDataType();
+ }
+ return DataType::Float32;
+}
+
+void Layer::ResetPriority() const
+{
+ m_Priority = 0;
+ m_Visiting = false;
+}
+
// Computes (and memoizes in m_Priority) this layer's topological priority:
// input layers get the lowest sentinel, output layers the highest, and any
// other layer gets 1 + the maximum priority of its parents. Throws
// GraphValidationException on cycles or when the depth would reach the
// output sentinel.
LayerPriority Layer::GetPriority() const
{
    // Sentinel priorities: inputs sit at the very bottom, outputs at the top.
    constexpr LayerPriority inputPrio = std::numeric_limits<LayerPriority>::lowest();
    constexpr LayerPriority outputPrio = std::numeric_limits<LayerPriority>::max();

    if (GetType() == LayerType::Input)
    {
        m_Priority = inputPrio;
    }
    else if (GetType() == LayerType::Output)
    {
        m_Priority = outputPrio;
    }
    else if (m_Priority == 0)
    {
        // 0 means "not computed yet"; walk the parents recursively.
        if (m_Visiting)
        {
            // Re-entered this layer while its own priority was still being
            // computed, so the graph must contain a cycle.
            throw GraphValidationException("Graph has circular dependencies: cannot walk");
        }

        // Fold over the input slots, keeping the highest parent priority.
        auto maxPrio = [](const LayerPriority prio, const InputSlot& slot) -> LayerPriority
        {
            const Layer& input = slot.GetConnectedOutputSlot()->GetOwningLayer();
            return std::max(prio, input.GetPriority());
        };

        m_Visiting = true;
        LayerPriority parentPrio = std::accumulate(GetInputSlots().cbegin(), GetInputSlots().cend(), 0U, maxPrio);
        m_Visiting = false;

        // A parent already at the output sentinel would overflow the +1 below.
        if (parentPrio >= outputPrio)
        {
            throw GraphValidationException("Graph has too many edges");
        }

        // This layer sits one level deeper than its deepest parent.
        m_Priority = parentPrio + 1U;
    }

    return m_Priority;
}
+
+} // namespace armnn