diff options
Diffstat (limited to 'src/armnn/layers/SplitterLayer.cpp')
-rw-r--r-- | src/armnn/layers/SplitterLayer.cpp | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/src/armnn/layers/SplitterLayer.cpp b/src/armnn/layers/SplitterLayer.cpp index 4a6b2220a7..dc04b3fd15 100644 --- a/src/armnn/layers/SplitterLayer.cpp +++ b/src/armnn/layers/SplitterLayer.cpp @@ -32,7 +32,8 @@ std::unique_ptr<IWorkload> SplitterLayer::CreateWorkload(const Graph& graph, con return factory.CreateSplitter(descriptor, PrepInfoAndDesc(descriptor, graph)); } -void SplitterLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory) +template<typename FactoryType> +void SplitterLayer::CreateTensors(const FactoryType& factory) { //If sub tensors are supported than all the "splitter" need to do is to //set the outputs to be appropriate sub tensors of the input. @@ -40,6 +41,7 @@ void SplitterLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& fa if (useSubTensors) { + const OutputSlot* slot = GetInputSlots()[0].GetConnectedOutputSlot(); const OutputHandler& outputHandler = GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler(); const TensorInfo& parentInfo = outputHandler.GetTensorInfo(); @@ -53,10 +55,13 @@ void SplitterLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& fa { const TensorInfo& info = m_OutputHandlers[i].GetTensorInfo(); + OutputSlot& outSlot = GetOutputSlot(i); + ITensorHandleFactory::FactoryId factoryId = outSlot.GetTensorHandleFactoryId(); auto CreateSubTensor = [&]() { // Make sure quantization parameters are in the same space - if (parentInfo.IsTypeSpaceMatch(info)) + if (parentInfo.IsTypeSpaceMatch(info) && + factoryId == slot->GetTensorHandleFactoryId()) { return factory.CreateSubTensorHandle(*inputData, info.GetShape(), @@ -95,6 +100,24 @@ void SplitterLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& fa } } +void SplitterLayer::CreateTensorHandles(const TensorHandleFactoryRegistry& registry, + const IWorkloadFactory& workloadFactory) +{ + OutputSlot& slot = GetOutputSlot(0); + ITensorHandleFactory::FactoryId factoryId = slot.GetTensorHandleFactoryId(); + + if (factoryId == ITensorHandleFactory::LegacyFactoryId) + { + CreateTensors(workloadFactory); + } + else + { + ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId); + BOOST_ASSERT(handleFactory); + CreateTensors(*handleFactory); + } +} + SplitterLayer* SplitterLayer::Clone(Graph& graph) const { return CloneBase<SplitterLayer>(graph, m_Param, GetName()); |