diff options
Diffstat (limited to 'src/backends/backendsCommon/SubgraphUtils.hpp')
-rw-r--r-- | src/backends/backendsCommon/SubgraphUtils.hpp | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/SubgraphUtils.hpp b/src/backends/backendsCommon/SubgraphUtils.hpp index ade4b63976..823da76f29 100644 --- a/src/backends/backendsCommon/SubgraphUtils.hpp +++ b/src/backends/backendsCommon/SubgraphUtils.hpp @@ -199,6 +199,47 @@ LayerType* FoldPadLayer(OptimizationViews& optimizationViews, return replacementLayer; } +/// Checks if the Layer is connected to any Layer that has an NCHW layout. +inline bool ConnectedToLayerWithNCHW(Layer* baseLayer) +{ + Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer(); + + if (IsNCHW(parentLayer)) + { + return true; + } + for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i) + { + Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer(); + if (IsNCHW(nextLayer)) + { + return true; + } + } + return false; +} + +/// Checks if the Layer is connected to a Splitter Layer through a Tensor that has more than 4 dimensions. +inline bool ConnectedToSplitterWithMoreThan4Dims(Layer* baseLayer) +{ + Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer(); + TensorInfo parentTensorInfo = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(); + if (parentTensorInfo.GetNumDimensions() > 4 && parentLayer.GetType() == LayerType::Splitter) + { + return true; + } + for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i) + { + Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer(); + TensorInfo nextTensorInfo = baseLayer->GetOutputSlot(0).GetConnection(i)->GetTensorInfo(); + if (nextTensorInfo.GetNumDimensions() > 4 && nextLayer.GetType() == LayerType::Splitter) + { + return true; + } + } + return false; +} + inline void RemoveReshapeLayer(ReshapeLayer* baseLayer, std::map<LayerGuid, Layer*>& untouched, OptimizationViews& optimizationViews) |