aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/SubgraphUtils.hpp
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2023-07-18 12:03:41 +0100
committermike.kelly <mike.kelly@arm.com>2023-07-21 16:01:50 +0000
commitb6de7a1444c09c0eb44c84a923c45c041b1f6092 (patch)
treec1d9e9241854e7d9bbb09da4f4a6f4652ff5c66c /src/backends/backendsCommon/SubgraphUtils.hpp
parentc32adef195b523854144737ca05180235f5ca824 (diff)
downloadarmnn-b6de7a1444c09c0eb44c84a923c45c041b1f6092.tar.gz
IVGCVSW-7830 Clean up
* Follow up review to clean up whitespace and copyright errors mentioned in https://review.mlplatform.org/c/ml/armnn/+/9885 * Added BinaryElementwiseOperation to .dot files * Refactored ConnectedToSplitterWithMoreThan4Dims function to more generally useful ConnectedToLayerType function Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I0e3d0895888f3a3f0a9758ce30bc031aba50812b
Diffstat (limited to 'src/backends/backendsCommon/SubgraphUtils.hpp')
-rw-r--r--src/backends/backendsCommon/SubgraphUtils.hpp15
1 files changed, 9 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/SubgraphUtils.hpp b/src/backends/backendsCommon/SubgraphUtils.hpp
index 823da76f29..9f2cdba6ef 100644
--- a/src/backends/backendsCommon/SubgraphUtils.hpp
+++ b/src/backends/backendsCommon/SubgraphUtils.hpp
@@ -33,7 +33,7 @@ public:
{
case armnn::LayerType::BatchMatMul:
{
- auto desc = static_cast<const armnn::BatchMatMulDescriptor &>(descriptor);
+ auto desc = static_cast<const armnn::BatchMatMulDescriptor&>(descriptor);
m_Result = desc.m_DataLayoutX == DataLayout::NCHW || desc.m_DataLayoutY == DataLayout::NCHW;
break;
}
@@ -219,12 +219,14 @@ inline bool ConnectedToLayerWithNCHW(Layer* baseLayer)
return false;
}
-/// Checks if the Layer is connected to a Splitter Layer through a Tensor that has more than 4 dimensions.
-inline bool ConnectedToSplitterWithMoreThan4Dims(Layer* baseLayer)
+/// Checks the Layer's Connections to see if it's connected to a Layer with the provided layerType. If dimSize is
+/// provided will also check if the connecting Tensor has more than that number of dimensions
+inline bool ConnectedToLayerType(Layer* baseLayer, LayerType layerType, unsigned int dimSize = 0)
{
Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
- TensorInfo parentTensorInfo = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
- if (parentTensorInfo.GetNumDimensions() > 4 && parentLayer.GetType() == LayerType::Splitter)
+ TensorInfo parentTensorInfo = baseLayer->GetInputSlot(0).GetTensorInfo();
+
+ if (parentTensorInfo.GetNumDimensions() > dimSize && parentLayer.GetType() == layerType)
{
return true;
}
@@ -232,7 +234,8 @@ inline bool ConnectedToSplitterWithMoreThan4Dims(Layer* baseLayer)
{
Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
TensorInfo nextTensorInfo = baseLayer->GetOutputSlot(0).GetConnection(i)->GetTensorInfo();
- if (nextTensorInfo.GetNumDimensions() > 4 && nextLayer.GetType() == LayerType::Splitter)
+
+ if (nextTensorInfo.GetNumDimensions() > dimSize && nextLayer.GetType() == layerType)
{
return true;
}