diff options
Diffstat (limited to 'src/backends/backendsCommon/SubgraphUtils.hpp')
-rw-r--r-- | src/backends/backendsCommon/SubgraphUtils.hpp | 15 |
1 files changed, 9 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/SubgraphUtils.hpp b/src/backends/backendsCommon/SubgraphUtils.hpp index 823da76f29..9f2cdba6ef 100644 --- a/src/backends/backendsCommon/SubgraphUtils.hpp +++ b/src/backends/backendsCommon/SubgraphUtils.hpp @@ -33,7 +33,7 @@ public: { case armnn::LayerType::BatchMatMul: { - auto desc = static_cast<const armnn::BatchMatMulDescriptor &>(descriptor); + auto desc = static_cast<const armnn::BatchMatMulDescriptor&>(descriptor); m_Result = desc.m_DataLayoutX == DataLayout::NCHW || desc.m_DataLayoutY == DataLayout::NCHW; break; } @@ -219,12 +219,14 @@ inline bool ConnectedToLayerWithNCHW(Layer* baseLayer) return false; } -/// Checks if the Layer is connected to a Splitter Layer through a Tensor that has more than 4 dimensions. -inline bool ConnectedToSplitterWithMoreThan4Dims(Layer* baseLayer) +/// Checks the Layer's Connections to see if it's connected to a Layer with the provided layerType. If dimSize is +/// provided will also check if the connecting Tensor has more than that number of dimensions +inline bool ConnectedToLayerType(Layer* baseLayer, LayerType layerType, unsigned int dimSize = 0) { Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer(); - TensorInfo parentTensorInfo = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(); - if (parentTensorInfo.GetNumDimensions() > 4 && parentLayer.GetType() == LayerType::Splitter) + TensorInfo parentTensorInfo = baseLayer->GetInputSlot(0).GetTensorInfo(); + + if (parentTensorInfo.GetNumDimensions() > dimSize && parentLayer.GetType() == layerType) { return true; } @@ -232,7 +234,8 @@ inline bool ConnectedToSplitterWithMoreThan4Dims(Layer* baseLayer) { Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer(); TensorInfo nextTensorInfo = baseLayer->GetOutputSlot(0).GetConnection(i)->GetTensorInfo(); - if (nextTensorInfo.GetNumDimensions() > 4 && nextLayer.GetType() == LayerType::Splitter) + + if (nextTensorInfo.GetNumDimensions() > dimSize && nextLayer.GetType() == layerType) { return true; } |