aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/MergerLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/MergerLayer.cpp')
-rw-r--r--src/armnn/layers/MergerLayer.cpp56
1 files changed, 48 insertions, 8 deletions
diff --git a/src/armnn/layers/MergerLayer.cpp b/src/armnn/layers/MergerLayer.cpp
index f87f34925f..c674f64f3f 100644
--- a/src/armnn/layers/MergerLayer.cpp
+++ b/src/armnn/layers/MergerLayer.cpp
@@ -36,14 +36,12 @@ std::unique_ptr<IWorkload> MergerLayer::CreateWorkload(const Graph& graph, const
void MergerLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory)
{
- //If sub tensors are supported than the merger
+ //If sub tensors are supported then the merger
//just needs to make sure that the outputs of the prev layer
//are made subtensors of the output of the merger layer.
m_OutputHandlers[0].CreateTensorHandles(factory);
- unsigned int innerAxis = m_Param.GetNumDimensions() - m_Param.GetConcatAxis();
-
- if (factory.SupportsSubTensors() && innerAxis != 1)
+ if (factory.SupportsSubTensors())
{
std::queue<MergerLayer*> m_MergerLayers;
@@ -52,23 +50,65 @@ void MergerLayer::CreateTensorHandles(Graph& graph, const IWorkloadFactory& fact
{
MergerLayer* currentLayer = m_MergerLayers.front();
ITensorHandle* parentTensor = currentLayer->GetOutputHandler(0).GetData();
-
+ const TensorInfo& parentInfo = currentLayer->GetOutputHandler(0).GetTensorInfo();
m_MergerLayers.pop();
const unsigned int numInputSlots = currentLayer->GetNumInputSlots();
+
+ // First go through all the input slots and verify that we can sub-tensor all the inputs.
+ std::vector<std::unique_ptr<ITensorHandle>> subTensors(0);
+ subTensors.reserve(numInputSlots);
for (unsigned int i = 0; i < numInputSlots; ++i)
{
OutputSlot* slot = currentLayer->GetInputSlot(i).GetConnectedOutputSlot();
+ const TensorInfo& info = slot->GetTensorInfo();
+
+ auto CreateSubTensor = [&]()
+ {
+ // Make sure quantization parameters are in the same space
+ if (parentInfo.IsTypeSpaceMatch(info))
+ {
+ return factory.CreateSubTensorHandle(*parentTensor,
+ info.GetShape(),
+ currentLayer->m_Param.GetViewOrigin(i));
+ }
+ return std::unique_ptr<ITensorHandle>();
+ };
+
+ auto subTensor = CreateSubTensor();
+ if (!subTensor)
+ {
+ break; //Failed to create a valid sub-tensor, so stop trying with the rest of the inputs.
+ }
+ else
+ {
+ subTensors.push_back(std::move(subTensor)); // store the valid sub-tensor.
+ }
+ }
+
+ // Ensure that ALL inputs can be substituted with valid sub-tensors
+ if (subTensors.size() < numInputSlots)
+ {
+ continue; // Don't optimize this Merge layer with sub-tensors
+ }
+
+ // Substitute input tensors with sub-tensors by replacing the output tensors on the connected layers.
+ unsigned int i=0;
+ for (auto& subTensor : subTensors)
+ {
+ OutputSlot* slot = currentLayer->GetInputSlot(i).GetConnectedOutputSlot();
OutputHandler& outputHandler = slot->GetOutputHandler();
- outputHandler.SetData(factory.CreateSubTensorHandle(*parentTensor,
- outputHandler.GetTensorInfo().GetShape(),
- currentLayer->m_Param.GetViewOrigin(i)));
+
+ BOOST_ASSERT_MSG(subTensor, "MergerLayer: Expected a valid sub-tensor for substitution.");
+ outputHandler.SetData(std::move(subTensor));
Layer& inputLayer = slot->GetOwningLayer();
if (inputLayer.GetType() == LayerType::Merger)
{
+ // Continue with the substitution if the connected inputs are also merger layers
m_MergerLayers.push(boost::polymorphic_downcast<MergerLayer*>(&inputLayer));
}
+ ++i;
}
}
}