path: root/src/armnn/layers/SplitterLayer.cpp
Diffstat (limited to 'src/armnn/layers/SplitterLayer.cpp')
-rw-r--r-- src/armnn/layers/SplitterLayer.cpp | 85
1 file changed, 77 insertions(+), 8 deletions(-)
diff --git a/src/armnn/layers/SplitterLayer.cpp b/src/armnn/layers/SplitterLayer.cpp
index 2d469b0bc9..72f27f7f01 100644
--- a/src/armnn/layers/SplitterLayer.cpp
+++ b/src/armnn/layers/SplitterLayer.cpp
@@ -33,7 +33,7 @@ std::unique_ptr<IWorkload> SplitterLayer::CreateWorkload(const IWorkloadFactory&
}
template<typename FactoryType>
-void SplitterLayer::CreateTensors(const FactoryType& factory)
+void SplitterLayer::CreateTensors(const TensorHandleFactoryRegistry& registry, const FactoryType& factory)
{
//If sub tensors are supported then all the "splitter" needs to do is to
//set the outputs to be appropriate sub tensors of the input.
@@ -41,8 +41,9 @@ void SplitterLayer::CreateTensors(const FactoryType& factory)
if (useSubTensors)
{
- const OutputSlot* slot = GetInputSlots()[0].GetConnectedOutputSlot();
+ // Get the OutputHandler of the previous layer
const OutputHandler& outputHandler = GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();
+ const OutputSlot* slot = GetInputSlots()[0].GetConnectedOutputSlot();
const TensorInfo& parentInfo = outputHandler.GetTensorInfo();
@@ -50,6 +51,36 @@ void SplitterLayer::CreateTensors(const FactoryType& factory)
std::vector<std::unique_ptr<ITensorHandle>> subTensors;
+ // Check if the split is along x or y (the 2 innermost dimensions)
+ auto numberOfDimensions = m_Param.GetNumDimensions();
+
+ // Compute the split axis within this class, as including the aclCommon helper causes header issues
+ auto ComputeSplitAxis = [&](const armnn::SplitterDescriptor& desc, const TensorShape& input)
+ {
+ unsigned int numSplit = desc.GetNumViews();
+ unsigned int numDimensions = desc.GetNumDimensions();
+ std::set<unsigned int> splitAxis;
+
+ for (unsigned int i = 0; i < numSplit; ++i)
+ {
+ for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
+ {
+ if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
+ {
+ splitAxis.insert(dimIdx);
+ }
+ }
+ }
+ return splitAxis;
+ };
+
+ std::set<unsigned int> axis = ComputeSplitAxis(m_Param, parentInfo.GetShape());
+ std::set<unsigned int>::iterator axisIt = axis.begin();
+
+ bool isOnXorY = m_Param.GetNumDimensions() >= 3 &&
+ ((*axisIt == numberOfDimensions - 1) ||
+ (*axisIt == numberOfDimensions - 2));
+
//Creates the outputs as subtensors of the input.
for (unsigned int i = 0; i < m_Param.GetNumViews(); ++i)
{
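
For reference, a minimal standalone sketch of the axis-detection logic the hunk above adds, using plain std::vector shapes in place of Arm NN's SplitterDescriptor and TensorShape (all names here are illustrative, not part of the patch): any dimension where a view's size differs from the input's size is treated as part of the split axis.

// Standalone sketch: a dimension belongs to the split axis if any view's
// size differs from the input's size in that dimension.
#include <cassert>
#include <set>
#include <vector>

std::set<unsigned int> ComputeSplitAxisSketch(
    const std::vector<std::vector<unsigned int>>& viewSizes,
    const std::vector<unsigned int>& inputShape)
{
    std::set<unsigned int> splitAxis;
    for (const auto& view : viewSizes)
    {
        for (unsigned int dimIdx = 0; dimIdx < inputShape.size(); ++dimIdx)
        {
            if (view[dimIdx] != inputShape[dimIdx])
            {
                splitAxis.insert(dimIdx);
            }
        }
    }
    return splitAxis;
}

int main()
{
    // Splitting an NCHW tensor [1,6,4,4] into two views [1,3,4,4]:
    // only the channel dimension (index 1) differs, so it is the split axis.
    std::vector<unsigned int> input = {1, 6, 4, 4};
    std::vector<std::vector<unsigned int>> views = {{1, 3, 4, 4}, {1, 3, 4, 4}};
    auto axis = ComputeSplitAxisSketch(views, input);
    assert(axis.size() == 1 && *axis.begin() == 1);
    return 0;
}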
@@ -57,11 +88,50 @@ void SplitterLayer::CreateTensors(const FactoryType& factory)
OutputSlot& outSlot = GetOutputSlot(i);
ITensorHandleFactory::FactoryId factoryId = outSlot.GetTensorHandleFactoryId();
+
+ const unsigned int numOutputSlots = GetNumOutputSlots();
+
+ // Sub-tensors can only be used for a split along x or y (the 2 innermost dimensions) if the next layers do not require padding
+ bool canUseSubTensorOnXorY = true;
+ bool isTensorHandleFactory = std::is_same<armnn::ITensorHandleFactory, FactoryType>::value;
+ if (isTensorHandleFactory)
+ {
+ for (unsigned int it = 0; it < numOutputSlots; ++it)
+ {
+ InputSlot* inputSlot = GetOutputSlot(it).GetConnection(0);
+ ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
+ std::vector<Capability> capabilities =
+ handleFactory->GetCapabilities(&(inputSlot->GetOwningLayer()),
+ this,
+ CapabilityClass::PaddingRequired);
+ if (isOnXorY)
+ {
+ canUseSubTensorOnXorY = false;
+ if (capabilities.empty())
+ {
+ canUseSubTensorOnXorY = true;
+ }
+ }
+
+ if (!canUseSubTensorOnXorY)
+ {
+ break;
+ }
+ }
+ }
+
auto CreateSubTensor = [&]()
{
- // Make sure quantization parameters are in the same space
- if (parentInfo.IsTypeSpaceMatch(info) &&
- factoryId == slot->GetTensorHandleFactoryId())
+ // Make sure:
+ // 1) quantization parameters are in the same space
+ // 2) the same TensorHandleFactory is used for the input and the split layer output
+ // 3) the output does not go to a Constant layer or an Input layer
+ // 4) if the split is along x or y (the 2 innermost dimensions), the next layers do not require padding
+ if (parentInfo.IsTypeSpaceMatch(info) && //(1)
+ factoryId == slot->GetTensorHandleFactoryId() && //(2)
+ GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Constant && //(3)
+ GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Input && //(3)
+ canUseSubTensorOnXorY) //(4)
{
return factory.CreateSubTensorHandle(*inputData,
info.GetShape(),
@@ -87,7 +157,6 @@ void SplitterLayer::CreateTensors(const FactoryType& factory)
m_OutputHandlers[i].SetData(std::move(subTensor));
++i;
}
-
}
}
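
A simplified sketch of the gating decision this hunk introduces, with hypothetical stand-in types rather than Arm NN's layer and factory classes: a view is realised as a sub-tensor only when the quantization spaces match, the same tensor handle factory is used, no consumer is a Constant or Input layer, and, for a split along the two innermost dimensions, no consumer reports a PaddingRequired capability.

#include <vector>

// Hypothetical stand-in for what is queried from a connected (consumer) layer.
struct ConsumerInfo
{
    bool isConstantOrInput;  // consumer is a Constant or Input layer (condition 3)
    bool requiresPadding;    // GetCapabilities(..., PaddingRequired) returned entries
};

// Returns true when a view may be created as a sub-tensor of the input.
bool CanUseSubTensor(bool typeSpaceMatch,        // (1) quantization spaces match
                     bool sameFactory,           // (2) same tensor handle factory
                     bool splitOnInnermostDims,  // split axis is x or y
                     const std::vector<ConsumerInfo>& consumers)
{
    bool canUseSubTensorOnXorY = true;
    for (const auto& consumer : consumers)
    {
        if (consumer.isConstantOrInput)
        {
            return false;                                    // (3)
        }
        if (splitOnInnermostDims && consumer.requiresPadding)
        {
            canUseSubTensorOnXorY = false;                   // (4)
        }
    }
    return typeSpaceMatch && sameFactory && canUseSubTensorOnXorY;
}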
@@ -110,13 +179,13 @@ void SplitterLayer::CreateTensorHandles(const TensorHandleFactoryRegistry& regis
if (factoryId == ITensorHandleFactory::LegacyFactoryId)
{
- CreateTensors(workloadFactory);
+ CreateTensors(registry, workloadFactory);
}
else
{
ITensorHandleFactory* handleFactory = registry.GetFactory(factoryId);
ARMNN_ASSERT(handleFactory);
- CreateTensors(*handleFactory);
+ CreateTensors(registry, *handleFactory);
}
}
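
The capability query above is only attempted when FactoryType is an ITensorHandleFactory, not on the legacy IWorkloadFactory path. A standalone sketch of that compile-time dispatch, with hypothetical factory types standing in for the real ones:

// Sketch of the std::is_same dispatch used in CreateTensors: the padding
// capability query is only meaningful for the tensor-handle-factory path.
#include <iostream>
#include <type_traits>

struct LegacyFactory {};  // hypothetical stand-in for IWorkloadFactory
struct HandleFactory {};  // hypothetical stand-in for ITensorHandleFactory

template <typename FactoryType>
void CreateTensorsSketch(const FactoryType&)
{
    constexpr bool isHandleFactory = std::is_same<HandleFactory, FactoryType>::value;
    if (isHandleFactory)
    {
        std::cout << "query padding capabilities via the registry\n";
    }
    else
    {
        std::cout << "legacy path: skip the capability query\n";
    }
}

int main()
{
    CreateTensorsSketch(LegacyFactory{});
    CreateTensorsSketch(HandleFactory{});
    return 0;
}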