diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2020-08-04 14:01:05 +0100 |
---|---|---|
committer | KeithARM <keith.davis@arm.com> | 2020-08-07 12:44:19 +0000 |
commit | 76615a5edd55b890acdd5fb078d9242e1e719a45 (patch) | |
tree | 5c75480b2f7594fa8934952d0037da543026d58c /src/armnn | |
parent | e88167264991b8debe56f095abbc262f7266c5d4 (diff) | |
download | armnn-76615a5edd55b890acdd5fb078d9242e1e719a45.tar.gz |
IVGCVSW-5108 Allow Concat to use subtensor on x and y
* Updated ConcatLayer to allow using subtensors on x/y if padding is not required
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I46a8fb9f17b976b76e069bb82614b6628a206717
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/layers/ConcatLayer.cpp | 45 | ||||
-rw-r--r-- | src/armnn/layers/ConcatLayer.hpp | 2 |
2 files changed, 42 insertions, 5 deletions
diff --git a/src/armnn/layers/ConcatLayer.cpp b/src/armnn/layers/ConcatLayer.cpp index d9fffff57e..fac6a1f197 100644 --- a/src/armnn/layers/ConcatLayer.cpp +++ b/src/armnn/layers/ConcatLayer.cpp @@ -36,7 +36,7 @@ std::unique_ptr<IWorkload> ConcatLayer::CreateWorkload(const IWorkloadFactory& f } template<typename FactoryType> -void ConcatLayer::CreateTensors(const FactoryType& factory) +void ConcatLayer::CreateTensors(const TensorHandleFactoryRegistry& registry, const FactoryType& factory) { //If sub tensors are supported then the concat //just needs to make sure that the outputs of the prev layer @@ -45,6 +45,12 @@ void ConcatLayer::CreateTensors(const FactoryType& factory) if (factory.SupportsSubTensors()) { + // check if concat is along the x or y (2 innermost dimensions) + uint32_t concatAxis = m_Param.GetConcatAxis(); + auto numberOfDimensions = m_Param.GetNumDimensions(); + bool isConcatOnXorY = m_Param.GetNumDimensions() >= 3 + && ((concatAxis == numberOfDimensions - 1) || (concatAxis == numberOfDimensions - 2)); + ITensorHandleFactory::FactoryId factoryId = GetOutputSlot(0).GetTensorHandleFactoryId(); std::queue<ConcatLayer*> m_ConcatLayers; @@ -59,6 +65,35 @@ void ConcatLayer::CreateTensors(const FactoryType& factory) const unsigned int numInputSlots = currentLayer->GetNumInputSlots(); + // if concat along x or y (2 innermost dimensions) and the previous layers do not require padding + bool canUseSubTensorOnXorY = true; + bool isTensorHandleFactory = std::is_same<armnn::ITensorHandleFactory, FactoryType>::value; + if (isTensorHandleFactory) + { + for (unsigned int i = 0; i < numInputSlots; ++i) + { + OutputSlot* slot = currentLayer->GetInputSlot(i).GetConnectedOutputSlot(); + ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId); + std::vector<Capability> capabilities = + handleFactory->GetCapabilities(&(slot->GetOwningLayer()), + currentLayer, + CapabilityClass::PaddingRequired); + if (isConcatOnXorY) + { + canUseSubTensorOnXorY = false; + if (capabilities.empty()) + { + canUseSubTensorOnXorY = true; + } + } + + if (!canUseSubTensorOnXorY) + { + break; + } + } + } + // First go through all the input slots and verify that we can sub-tensor all the inputs. std::vector<std::unique_ptr<ITensorHandle>> subTensors(0); subTensors.reserve(numInputSlots); @@ -74,12 +109,14 @@ void ConcatLayer::CreateTensors(const FactoryType& factory) // 2) the same TensorHandleFactory is used for input and Concat layer output // 3) the input does not come from a Constant layer or input layer // 4) the input is only read by this concat layer + // 5) if concat along x or y (2 innermost dimensions) and the previous layers do not require padding if (slot && parentInfo.IsTypeSpaceMatch(info) && //(1) factoryId == slot->GetTensorHandleFactoryId() && //(2) slot->GetOwningLayer().GetType() != LayerType::Constant && //(3) slot->GetOwningLayer().GetType() != LayerType::Input && //(3) - slot->GetNumConnections() == 1) //(4) + slot->GetNumConnections() == 1 && + canUseSubTensorOnXorY) //(5) { return factory.CreateSubTensorHandle(*parentTensor, info.GetShape(), @@ -137,13 +174,13 @@ void ConcatLayer::CreateTensorHandles(const TensorHandleFactoryRegistry& registr if (factoryId == ITensorHandleFactory::LegacyFactoryId) { - CreateTensors(workloadFactory); + CreateTensors(registry, workloadFactory); } else { ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId); ARMNN_ASSERT(handleFactory); - CreateTensors(*handleFactory); + CreateTensors(registry, *handleFactory); } } diff --git a/src/armnn/layers/ConcatLayer.hpp b/src/armnn/layers/ConcatLayer.hpp index 84eba2e7c9..eaa5c15a9c 100644 --- a/src/armnn/layers/ConcatLayer.hpp +++ b/src/armnn/layers/ConcatLayer.hpp @@ -56,7 +56,7 @@ protected: private: template <typename FactoryType> - void CreateTensors(const FactoryType& factory); + void CreateTensors(const TensorHandleFactoryRegistry& registry, const FactoryType& factory); };