diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/Graph.cpp | 10 | ||||
-rw-r--r-- | src/armnn/Layer.cpp | 15 | ||||
-rw-r--r-- | src/armnn/Layer.hpp | 8 | ||||
-rw-r--r-- | src/armnn/LoadedNetwork.cpp | 28 | ||||
-rw-r--r-- | src/armnn/Network.cpp | 22 | ||||
-rw-r--r-- | src/armnn/layers/ConcatLayer.cpp | 3 | ||||
-rw-r--r-- | src/armnn/layers/ReverseV2Layer.cpp | 2 | ||||
-rw-r--r-- | src/armnn/layers/SplitterLayer.cpp | 5 |
8 files changed, 78 insertions, 15 deletions
diff --git a/src/armnn/Graph.cpp b/src/armnn/Graph.cpp index fee1da4343..cf6f20f82b 100644 --- a/src/armnn/Graph.cpp +++ b/src/armnn/Graph.cpp @@ -508,12 +508,22 @@ void Graph::ReplaceSubgraphConnections(const SubgraphView& subgraph, const Subgr if (subgraphInputSlot->GetConnection()) { IOutputSlot* connectedOutputSlot = subgraphInputSlot->GetConnection(); + InputSlot* inputSlot = PolymorphicDowncast<InputSlot*>(subgraphInputSlot); + bool isOverridden = inputSlot->IsTensorInfoOverridden(); + ARMNN_ASSERT(connectedOutputSlot); connectedOutputSlot->Disconnect(*subgraphInputSlot); IInputSlot* substituteInputSlot = substituteSubgraphInputSlots.at(inputSlotIdx); ARMNN_ASSERT(substituteInputSlot); connectedOutputSlot->Connect(*substituteInputSlot); + + if (isOverridden) + { + TensorInfo overridden = inputSlot->GetTensorInfo(); + InputSlot* newInputSlot = PolymorphicDowncast<InputSlot*>(substituteInputSlot); + newInputSlot->SetTensorInfo(overridden); + } } } diff --git a/src/armnn/Layer.cpp b/src/armnn/Layer.cpp index 8d4811ae89..d2f8f2c982 100644 --- a/src/armnn/Layer.cpp +++ b/src/armnn/Layer.cpp @@ -259,7 +259,20 @@ void Layer::CollectWorkloadInputs(WorkloadDataCollector& dataCollector) const // The graph must be well-formed at this point. ARMNN_ASSERT(inputSlot.GetConnection()); const OutputHandler& outputHandler = inputSlot.GetConnectedOutputSlot()->GetOutputHandler(); - dataCollector.Push(outputHandler.GetData(), outputHandler.GetTensorInfo()); + + if (inputSlot.IsTensorInfoOverridden() && outputHandler.GetData()) + { + auto handler = outputHandler.GetData()->DecorateTensorHandle(inputSlot.GetTensorInfo()); + + if (handler) + { + // Add overridden TensorHandle + dataCollector.Push(handler.get(), inputSlot.GetTensorInfo()); + continue; + } + } + // Add default TensorHandle + dataCollector.Push(outputHandler.GetData(), inputSlot.GetTensorInfo()); } } diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp index 5e097f0fad..4f69e78b62 100644 --- a/src/armnn/Layer.hpp +++ b/src/armnn/Layer.hpp @@ -80,15 +80,15 @@ public: /// Sets the TensorInfo for this InputSlot. This can be used to override the TensorInfo and if set will be returned /// instead of the TensorInfo for the Connected OutputSlot. - void SetTensorInfo(const TensorInfo tensorInfo); + void SetTensorInfo(const TensorInfo tensorInfo) override; /// Gets the TensorInfo for this InputSlot. If the InputSlot's TensorInfo has not been set then this will get the /// TensorInfo from the Connected TensorInfo. - const TensorInfo& GetTensorInfo() const; + const TensorInfo& GetTensorInfo() const override; /// Returns true if this InputSlot either has an overridden TensorInfo for this InputSlot that was set through a /// call to SetTensorInfo() or is Connected to an OutputSlot that has its TensorInfo set. - bool IsTensorInfoSet() const; + bool IsTensorInfoSet() const override; /// Returns true if this InputSlot has an overridden TensorInfo that was set through a call to SetTensorInfo(). - bool IsTensorInfoOverridden() const; + bool IsTensorInfoOverridden() const override; private: Layer& m_OwningLayer; diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index 3f4aa34a5b..3d84054b69 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -955,10 +955,10 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors, syncDesc.m_Inputs.push_back(inputTensorHandle); WorkloadInfo info; info.m_InputTensorInfos.push_back( - outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo()); + outputLayer->GetInputSlot(0).GetTensorInfo()); auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info); ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created"); - m_OutputQueue.push_back(move(syncWorkload)); + m_OutputQueue.push_back(std::move(syncWorkload)); importedOutputIdIndex++; } else @@ -1089,7 +1089,7 @@ void LoadedNetwork::EnqueueInput(const BindableLayer& layer, ITensorHandle* tens timelineUtils->Commit(); } - m_InputQueue.push_back(move(inputWorkload)); + m_InputQueue.push_back(std::move(inputWorkload)); } } @@ -1149,7 +1149,7 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten info.m_InputTensorInfos.push_back(inputTensorInfo); auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info); ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created"); - m_OutputQueue.push_back(move(syncWorkload)); + m_OutputQueue.push_back(std::move(syncWorkload)); } else { @@ -1177,7 +1177,7 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten timelineUtils->Commit(); } - m_OutputQueue.push_back(move(outputWorkload)); + m_OutputQueue.push_back(std::move(outputWorkload)); } } @@ -1650,7 +1650,7 @@ std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors& const InputSlot& inputSlot = layer->GetInputSlots()[0]; ITensorHandleFactory::FactoryId factoryId = inputSlot.GetConnectedOutputSlot()->GetTensorHandleFactoryId(); - const TensorInfo& tensorInfo = inputSlot.GetConnectedOutputSlot()->GetTensorInfo(); + const TensorInfo& tensorInfo = inputSlot.GetTensorInfo(); ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId); ARMNN_ASSERT(handleFactory); @@ -2093,6 +2093,14 @@ std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(Network if (found != m_ConstantTensorHandles.end()) { ITensorHandle* tensorHandle = found->second; + if (slot.IsTensorInfoOverridden()) + { + ITensorHandle* decorated = tensorHandle->DecorateTensorHandle(slot.GetTensorInfo()).get(); + if (decorated) + { + tensorHandle = decorated; + } + } workingMemDescriptor.m_Inputs.push_back(tensorHandle); // Odd case where a constant layer is connected to an output layer @@ -2113,6 +2121,14 @@ std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(Network HandleInfo& handleInfo = outputToHandleInfoMap.at(outputSlot); ITensorHandle* inputTensorHandle = handleInfo.m_TensorHandle; + if (slot.IsTensorInfoOverridden()) + { + ITensorHandle* decorated = inputTensorHandle->DecorateTensorHandle(slot.GetTensorInfo()).get(); + if (decorated) + { + inputTensorHandle = decorated; + } + } workingMemDescriptor.m_Inputs.push_back(inputTensorHandle); // Store the LayerBindingId of the OutputLayer diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 2abaf44587..ae5bde17ca 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1310,6 +1310,28 @@ OptimizationResult ApplyBackendOptimizations(OptimizedNetworkImpl* optNetObjPtr, }); } + // Remove deleted sub-graphs + for (auto& deletedSubgraph : optimizationViews.GetDeletedSubgraphs()) + { + for (auto& l : deletedSubgraph.GetIConnectableLayers()) + { + Layer* deletedLayer = PolymorphicDowncast<Layer*>(l); + for (unsigned int in = deletedLayer->GetNumInputSlots(); in > 0; --in) + { + auto inputSlot = deletedLayer->GetInputSlot(in -1); + OutputSlot* parentOut = inputSlot.GetConnectedOutputSlot(); + parentOut->Disconnect(inputSlot); + for (unsigned int out = deletedLayer->GetOutputSlot(in -1).GetNumConnections(); out > 0; --out) + { + InputSlot *childIn = deletedLayer->GetOutputSlot(in - 1).GetConnection(out -1); + deletedLayer->GetOutputSlot(in - 1).Disconnect(*childIn); + parentOut->Connect(*childIn); + } + } + optGraph.EraseLayer(deletedLayer); + } + } + if (!optimizationViews.GetFailedSubgraphs().empty()) { std::stringstream warningMsg; diff --git a/src/armnn/layers/ConcatLayer.cpp b/src/armnn/layers/ConcatLayer.cpp index 7a1b689b2c..4629bf245e 100644 --- a/src/armnn/layers/ConcatLayer.cpp +++ b/src/armnn/layers/ConcatLayer.cpp @@ -120,7 +120,7 @@ void ConcatLayer::CreateTensors(const TensorHandleFactoryRegistry& registry, // 3) the input does not come from a Constant layer or input layer // 4) the input is only read by this concat layer // 5) if concat along x or y (2 innermost dimensions) and the previous layers do not require padding - // 6) the input does not have an Overridden TensorInfo + // 6) neither the inputs nor the output have an Overridden TensorInfo if (slot && parentInfo.IsTypeSpaceMatch(info) && //(1) factoryId == slot->GetTensorHandleFactoryId() && //(2) @@ -128,6 +128,7 @@ void ConcatLayer::CreateTensors(const TensorHandleFactoryRegistry& registry, slot->GetOwningLayer().GetType() != LayerType::Input && //(3) slot->GetNumConnections() == 1 && canUseSubTensorOnXorY && //(5) + !GetOutputSlot(0).GetConnection(0)->IsTensorInfoOverridden() && //(6) !currentLayer->GetInputSlot(i).IsTensorInfoOverridden()) //(6) { ARMNN_NO_DEPRECATE_WARN_BEGIN diff --git a/src/armnn/layers/ReverseV2Layer.cpp b/src/armnn/layers/ReverseV2Layer.cpp index 29f8b1b781..201e19819b 100644 --- a/src/armnn/layers/ReverseV2Layer.cpp +++ b/src/armnn/layers/ReverseV2Layer.cpp @@ -40,7 +40,7 @@ void ReverseV2Layer::ValidateTensorShapesFromInputs() VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); auto inferredShapes = InferOutputShapes({ - GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() }); + GetInputSlot(0).GetTensorInfo().GetShape() }); ARMNN_ASSERT(inferredShapes.size() == 1); diff --git a/src/armnn/layers/SplitterLayer.cpp b/src/armnn/layers/SplitterLayer.cpp index 86a42305ff..dc8864a736 100644 --- a/src/armnn/layers/SplitterLayer.cpp +++ b/src/armnn/layers/SplitterLayer.cpp @@ -131,13 +131,14 @@ void SplitterLayer::CreateTensors(const TensorHandleFactoryRegistry& registry, // 2) the same TensorHandleFactory is used for input and split layer output // 3) the output does not go to a Constant layer or input layer // 4) if split along x or y (2 innermost dimensions) and the next layers do not require padding - // 5) none of the outputs have an Overridden TensorInfo + // 5) neither the input nor the outputs have an Overridden TensorInfo if (parentInfo.IsTypeSpaceMatch(info) && //(1) factoryId == slot->GetTensorHandleFactoryId() && //(2) GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Constant && //(3) GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Input && //(3) canUseSubTensorOnXorY && //(4) - !GetOutputSlot(i).GetConnection(0)->IsTensorInfoOverridden()) //(5) + !GetOutputSlot(i).GetConnection(0)->IsTensorInfoOverridden() && //(5) + !GetInputSlot(0).IsTensorInfoOverridden()) //(5) { ARMNN_NO_DEPRECATE_WARN_BEGIN return factory.CreateSubTensorHandle(*inputData, |