aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/ConcatLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/ConcatLayer.cpp')
-rw-r--r--src/armnn/layers/ConcatLayer.cpp45
1 files changed, 41 insertions, 4 deletions
diff --git a/src/armnn/layers/ConcatLayer.cpp b/src/armnn/layers/ConcatLayer.cpp
index d9fffff57e..fac6a1f197 100644
--- a/src/armnn/layers/ConcatLayer.cpp
+++ b/src/armnn/layers/ConcatLayer.cpp
@@ -36,7 +36,7 @@ std::unique_ptr<IWorkload> ConcatLayer::CreateWorkload(const IWorkloadFactory& f
}
template<typename FactoryType>
-void ConcatLayer::CreateTensors(const FactoryType& factory)
+void ConcatLayer::CreateTensors(const TensorHandleFactoryRegistry& registry, const FactoryType& factory)
{
//If sub tensors are supported then the concat
//just needs to make sure that the outputs of the prev layer
@@ -45,6 +45,12 @@ void ConcatLayer::CreateTensors(const FactoryType& factory)
if (factory.SupportsSubTensors())
{
+ // check if concat is along the x or y (2 innermost dimensions)
+ uint32_t concatAxis = m_Param.GetConcatAxis();
+ auto numberOfDimensions = m_Param.GetNumDimensions();
+ bool isConcatOnXorY = m_Param.GetNumDimensions() >= 3
+ && ((concatAxis == numberOfDimensions - 1) || (concatAxis == numberOfDimensions - 2));
+
ITensorHandleFactory::FactoryId factoryId = GetOutputSlot(0).GetTensorHandleFactoryId();
std::queue<ConcatLayer*> m_ConcatLayers;
@@ -59,6 +65,35 @@ void ConcatLayer::CreateTensors(const FactoryType& factory)
const unsigned int numInputSlots = currentLayer->GetNumInputSlots();
+ // if concat along x or y (2 innermost dimensions) and the previous layers do not require padding
+ bool canUseSubTensorOnXorY = true;
+ bool isTensorHandleFactory = std::is_same<armnn::ITensorHandleFactory, FactoryType>::value;
+ if (isTensorHandleFactory)
+ {
+ for (unsigned int i = 0; i < numInputSlots; ++i)
+ {
+ OutputSlot* slot = currentLayer->GetInputSlot(i).GetConnectedOutputSlot();
+ ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
+ std::vector<Capability> capabilities =
+ handleFactory->GetCapabilities(&(slot->GetOwningLayer()),
+ currentLayer,
+ CapabilityClass::PaddingRequired);
+ if (isConcatOnXorY)
+ {
+ canUseSubTensorOnXorY = false;
+ if (capabilities.empty())
+ {
+ canUseSubTensorOnXorY = true;
+ }
+ }
+
+ if (!canUseSubTensorOnXorY)
+ {
+ break;
+ }
+ }
+ }
+
// First go through all the input slots and verify that we can sub-tensor all the inputs.
std::vector<std::unique_ptr<ITensorHandle>> subTensors(0);
subTensors.reserve(numInputSlots);
@@ -74,12 +109,14 @@ void ConcatLayer::CreateTensors(const FactoryType& factory)
// 2) the same TensorHandleFactory is used for input and Concat layer output
// 3) the input does not come from a Constant layer or input layer
// 4) the input is only read by this concat layer
+ // 5) if concat along x or y (2 innermost dimensions) and the previous layers do not require padding
if (slot &&
parentInfo.IsTypeSpaceMatch(info) && //(1)
factoryId == slot->GetTensorHandleFactoryId() && //(2)
slot->GetOwningLayer().GetType() != LayerType::Constant && //(3)
slot->GetOwningLayer().GetType() != LayerType::Input && //(3)
- slot->GetNumConnections() == 1) //(4)
+ slot->GetNumConnections() == 1 &&
+ canUseSubTensorOnXorY) //(5)
{
return factory.CreateSubTensorHandle(*parentTensor,
info.GetShape(),
@@ -137,13 +174,13 @@ void ConcatLayer::CreateTensorHandles(const TensorHandleFactoryRegistry& registr
if (factoryId == ITensorHandleFactory::LegacyFactoryId)
{
- CreateTensors(workloadFactory);
+ CreateTensors(registry, workloadFactory);
}
else
{
ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
ARMNN_ASSERT(handleFactory);
- CreateTensors(*handleFactory);
+ CreateTensors(registry, *handleFactory);
}
}