diff options
Diffstat (limited to 'src/armnn/Layer.cpp')
-rw-r--r-- | src/armnn/Layer.cpp | 59 |
1 files changed, 55 insertions, 4 deletions
diff --git a/src/armnn/Layer.cpp b/src/armnn/Layer.cpp index ced87b095c..a287220702 100644 --- a/src/armnn/Layer.cpp +++ b/src/armnn/Layer.cpp @@ -30,7 +30,8 @@ void InputSlot::Insert(Layer& layer) // Connects inserted layer to parent. BOOST_ASSERT(layer.GetNumInputSlots() == 1); - prevSlot->Connect(layer.GetInputSlot(0)); + int idx = prevSlot->Connect(layer.GetInputSlot(0)); + prevSlot->SetMemoryStrategy(boost::numeric_cast<unsigned int>(idx), MemoryStrategy::Undefined); // Sets tensor info for inserted layer. const TensorInfo& tensorInfo = prevSlot->GetTensorInfo(); @@ -39,6 +40,7 @@ void InputSlot::Insert(Layer& layer) // Connects inserted layer to this. layer.GetOutputSlot(0).Connect(*this); + layer.GetOutputSlot(0).SetMemoryStrategy(0, MemoryStrategy::Undefined); } const InputSlot* OutputSlot::GetConnection(unsigned int index) const @@ -78,13 +80,24 @@ int OutputSlot::Connect(InputSlot& destination) { destination.SetConnection(this); m_Connections.push_back(&destination); + m_MemoryStrategies.push_back(MemoryStrategy::Undefined); return boost::numeric_cast<int>(m_Connections.size() - 1); } void OutputSlot::Disconnect(InputSlot& slot) { slot.SetConnection(nullptr); + auto it = std::find(m_Connections.begin(), m_Connections.end(), &slot); + + if (it == m_Connections.end()) + { + return; + } + + auto idx = std::distance(m_Connections.begin(), it); m_Connections.erase(std::remove(m_Connections.begin(), m_Connections.end(), &slot), m_Connections.end()); + + m_MemoryStrategies.erase(m_MemoryStrategies.begin() + idx); } void OutputSlot::DisconnectAll() @@ -100,6 +113,9 @@ void OutputSlot::MoveAllConnections(OutputSlot& destination) { while (GetNumConnections() > 0) { + BOOST_ASSERT_MSG(m_MemoryStrategies[0] == MemoryStrategy::Undefined, + "Cannot move connections once memory strategies have be established."); + InputSlot& connection = *GetConnection(0); Disconnect(connection); destination.Connect(connection); @@ -148,6 +164,26 @@ LayerGuid OutputSlot::GetOwningLayerGuid() const return GetOwningLayer().GetGuid(); } +void OutputSlot::SetTensorHandleFactory(const ITensorHandleFactory::FactoryId& id) +{ + m_TensorHandleFactoryId = id; +} + +ITensorHandleFactory::FactoryId OutputSlot::GetTensorHandleFactoryId() const +{ + return m_TensorHandleFactoryId; +} + +void OutputSlot::SetMemoryStrategy(unsigned int connectionIndex, MemoryStrategy strategy) +{ + m_MemoryStrategies[connectionIndex] = strategy; +} + +MemoryStrategy OutputSlot::GetMemoryStrategyForConnection(unsigned int connectionIdx) const +{ + return m_MemoryStrategies[connectionIdx]; +} + namespace { LayerGuid GenerateLayerGuid() { @@ -208,11 +244,26 @@ void Layer::CollectWorkloadOutputs(WorkloadDataCollector& dataCollector, const G } } -void Layer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory) +void Layer::CreateTensorHandles(const TensorHandleFactoryRegistry& registry, const IWorkloadFactory& workloadFactory) { - for (auto&& outputHandler : m_OutputHandlers) + for (unsigned int idx=0; idx < GetNumOutputSlots(); idx++) { - outputHandler.CreateTensorHandles(factory); + + OutputSlot& slot = GetOutputSlot(idx); + ITensorHandleFactory::FactoryId factoryId = slot.GetTensorHandleFactoryId(); + + OutputHandler& handler = GetOutputHandler(idx); + if (factoryId == ITensorHandleFactory::LegacyFactoryId) + { + handler.CreateTensorHandles(workloadFactory); + } + else + { + ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId); + BOOST_ASSERT(handleFactory); + + handler.CreateTensorHandles(*handleFactory); + } } } |