aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r--src/backends/backendsCommon/SubgraphUtils.hpp41
1 files changed, 41 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/SubgraphUtils.hpp b/src/backends/backendsCommon/SubgraphUtils.hpp
index ade4b63976..823da76f29 100644
--- a/src/backends/backendsCommon/SubgraphUtils.hpp
+++ b/src/backends/backendsCommon/SubgraphUtils.hpp
@@ -199,6 +199,47 @@ LayerType* FoldPadLayer(OptimizationViews& optimizationViews,
return replacementLayer;
}
+/// Checks if the Layer is connected to any Layer that has an NCHW layout.
+inline bool ConnectedToLayerWithNCHW(Layer* baseLayer)
+{
+ Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
+
+ if (IsNCHW(parentLayer))
+ {
+ return true;
+ }
+ for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
+ {
+ Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
+ if (IsNCHW(nextLayer))
+ {
+ return true;
+ }
+ }
+ return false;
+}
+
+/// Checks if the Layer is connected to a Splitter Layer through a Tensor that has more than 4 dimensions.
+inline bool ConnectedToSplitterWithMoreThan4Dims(Layer* baseLayer)
+{
+ Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
+ TensorInfo parentTensorInfo = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
+ if (parentTensorInfo.GetNumDimensions() > 4 && parentLayer.GetType() == LayerType::Splitter)
+ {
+ return true;
+ }
+ for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
+ {
+ Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
+ TensorInfo nextTensorInfo = baseLayer->GetOutputSlot(0).GetConnection(i)->GetTensorInfo();
+ if (nextTensorInfo.GetNumDimensions() > 4 && nextLayer.GetType() == LayerType::Splitter)
+ {
+ return true;
+ }
+ }
+ return false;
+}
+
inline void RemoveReshapeLayer(ReshapeLayer* baseLayer,
std::map<LayerGuid, Layer*>& untouched,
OptimizationViews& optimizationViews)