aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Layer.cpp
diff options
context:
space:
mode:
authorDerek Lamberti <derek.lamberti@arm.com>2019-06-13 11:40:08 +0100
committerDerek Lamberti <derek.lamberti@arm.com>2019-06-24 15:00:15 +0000
commit84da38b0f11ca3db0a439e510514be780f3933ff (patch)
tree56532f4842abc1ad00ae57bc20ddc72cada59b4c /src/armnn/Layer.cpp
parent9515c7ec4f4535fff2c8f2d3f88974474d3f3468 (diff)
downloadarmnn-84da38b0f11ca3db0a439e510514be780f3933ff.tar.gz
IVGCVSW-3277 Refactor TensorHandle factory API
* Added backend support for multiple types of TensorHandle factories * Refactored the backend API to enable new tensor strategies * Added mechanism to determine memory strategies during optimization * Perform mem-copy only when Direct access is not found * Explicitly deleted the copy-constructor from OutputSlot to prevent accidental local copies that would cause the DisconnectAll to be called by the destructor Change-Id: I7e812c8e5e6c1c20db1c5932749ac70fd93db7f8 Signed-off-by: Derek Lamberti <derek.lamberti@arm.com> Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
Diffstat (limited to 'src/armnn/Layer.cpp')
-rw-r--r--src/armnn/Layer.cpp59
1 files changed, 55 insertions, 4 deletions
diff --git a/src/armnn/Layer.cpp b/src/armnn/Layer.cpp
index ced87b095c..a287220702 100644
--- a/src/armnn/Layer.cpp
+++ b/src/armnn/Layer.cpp
@@ -30,7 +30,8 @@ void InputSlot::Insert(Layer& layer)
// Connects inserted layer to parent.
BOOST_ASSERT(layer.GetNumInputSlots() == 1);
- prevSlot->Connect(layer.GetInputSlot(0));
+ int idx = prevSlot->Connect(layer.GetInputSlot(0));
+ prevSlot->SetMemoryStrategy(boost::numeric_cast<unsigned int>(idx), MemoryStrategy::Undefined);
// Sets tensor info for inserted layer.
const TensorInfo& tensorInfo = prevSlot->GetTensorInfo();
@@ -39,6 +40,7 @@ void InputSlot::Insert(Layer& layer)
// Connects inserted layer to this.
layer.GetOutputSlot(0).Connect(*this);
+ layer.GetOutputSlot(0).SetMemoryStrategy(0, MemoryStrategy::Undefined);
}
const InputSlot* OutputSlot::GetConnection(unsigned int index) const
@@ -78,13 +80,24 @@ int OutputSlot::Connect(InputSlot& destination)
{
destination.SetConnection(this);
m_Connections.push_back(&destination);
+ m_MemoryStrategies.push_back(MemoryStrategy::Undefined);
return boost::numeric_cast<int>(m_Connections.size() - 1);
}
void OutputSlot::Disconnect(InputSlot& slot)
{
slot.SetConnection(nullptr);
+ auto it = std::find(m_Connections.begin(), m_Connections.end(), &slot);
+
+ if (it == m_Connections.end())
+ {
+ return;
+ }
+
+ auto idx = std::distance(m_Connections.begin(), it);
m_Connections.erase(std::remove(m_Connections.begin(), m_Connections.end(), &slot), m_Connections.end());
+
+ m_MemoryStrategies.erase(m_MemoryStrategies.begin() + idx);
}
void OutputSlot::DisconnectAll()
@@ -100,6 +113,9 @@ void OutputSlot::MoveAllConnections(OutputSlot& destination)
{
while (GetNumConnections() > 0)
{
+ BOOST_ASSERT_MSG(m_MemoryStrategies[0] == MemoryStrategy::Undefined,
+ "Cannot move connections once memory strategies have be established.");
+
InputSlot& connection = *GetConnection(0);
Disconnect(connection);
destination.Connect(connection);
@@ -148,6 +164,26 @@ LayerGuid OutputSlot::GetOwningLayerGuid() const
return GetOwningLayer().GetGuid();
}
+void OutputSlot::SetTensorHandleFactory(const ITensorHandleFactory::FactoryId& id)
+{
+ m_TensorHandleFactoryId = id;
+}
+
+ITensorHandleFactory::FactoryId OutputSlot::GetTensorHandleFactoryId() const
+{
+ return m_TensorHandleFactoryId;
+}
+
+void OutputSlot::SetMemoryStrategy(unsigned int connectionIndex, MemoryStrategy strategy)
+{
+ m_MemoryStrategies[connectionIndex] = strategy;
+}
+
+MemoryStrategy OutputSlot::GetMemoryStrategyForConnection(unsigned int connectionIdx) const
+{
+ return m_MemoryStrategies[connectionIdx];
+}
+
namespace {
LayerGuid GenerateLayerGuid()
{
@@ -208,11 +244,26 @@ void Layer::CollectWorkloadOutputs(WorkloadDataCollector& dataCollector, const G
}
}
-void Layer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory)
+void Layer::CreateTensorHandles(const TensorHandleFactoryRegistry& registry, const IWorkloadFactory& workloadFactory)
{
- for (auto&& outputHandler : m_OutputHandlers)
+ for (unsigned int idx=0; idx < GetNumOutputSlots(); idx++)
{
- outputHandler.CreateTensorHandles(factory);
+
+ OutputSlot& slot = GetOutputSlot(idx);
+ ITensorHandleFactory::FactoryId factoryId = slot.GetTensorHandleFactoryId();
+
+ OutputHandler& handler = GetOutputHandler(idx);
+ if (factoryId == ITensorHandleFactory::LegacyFactoryId)
+ {
+ handler.CreateTensorHandles(workloadFactory);
+ }
+ else
+ {
+ ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
+ BOOST_ASSERT(handleFactory);
+
+ handler.CreateTensorHandles(*handleFactory);
+ }
}
}