aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/SubgraphUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/SubgraphUtils.hpp')
-rw-r--r--src/backends/backendsCommon/SubgraphUtils.hpp186
1 files changed, 177 insertions, 9 deletions
diff --git a/src/backends/backendsCommon/SubgraphUtils.hpp b/src/backends/backendsCommon/SubgraphUtils.hpp
index bd3d698a98..ade4b63976 100644
--- a/src/backends/backendsCommon/SubgraphUtils.hpp
+++ b/src/backends/backendsCommon/SubgraphUtils.hpp
@@ -1,10 +1,12 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
+#include <armnn/StrategyBase.hpp>
+#include <armnn/Descriptors.hpp>
#include <optimizations/FoldPadIntoLayer2d.hpp>
namespace armnn
@@ -13,6 +15,118 @@ namespace armnn
namespace
{
+/// Checks if a Layer has a DataLayout that is either NCHW or NCDHW.
+class CheckForNCHW : public StrategyBase<NoThrowStrategy>
+{
+public:
+ CheckForNCHW()
+ {}
+
+ void ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id = 0) override
+ {
+ armnn::IgnoreUnused(layer, constants, id, name);
+ switch (layer->GetType())
+ {
+ case armnn::LayerType::BatchMatMul:
+ {
+ auto desc = static_cast<const armnn::BatchMatMulDescriptor &>(descriptor);
+ m_Result = desc.m_DataLayoutX == DataLayout::NCHW || desc.m_DataLayoutY == DataLayout::NCHW;
+ break;
+ }
+ case armnn::LayerType::BatchNormalization:
+ {
+ CheckDescForNCHW(static_cast<const armnn::BatchNormalizationDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::BatchToSpaceNd:
+ {
+ CheckDescForNCHW(static_cast<const armnn::BatchToSpaceNdDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::Convolution2d:
+ {
+ CheckDescForNCHW(static_cast<const armnn::Convolution2dDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::Convolution3d:
+ {
+ CheckDescForNCHW(static_cast<const armnn::Convolution3dDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::DepthwiseConvolution2d:
+ {
+ CheckDescForNCHW(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::InstanceNormalization:
+ {
+ CheckDescForNCHW(static_cast<const armnn::InstanceNormalizationDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::L2Normalization:
+ {
+ CheckDescForNCHW(static_cast<const armnn::L2NormalizationDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::Normalization:
+ {
+ CheckDescForNCHW(static_cast<const armnn::NormalizationDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::Pooling2d:
+ {
+ CheckDescForNCHW(static_cast<const armnn::Pooling2dDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::Pooling3d:
+ {
+ CheckDescForNCHW(static_cast<const armnn::Pooling3dDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::SpaceToBatchNd:
+ {
+ CheckDescForNCHW(static_cast<const armnn::SpaceToBatchNdDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::SpaceToDepth:
+ {
+ CheckDescForNCHW(static_cast<const armnn::SpaceToDepthDescriptor&>(descriptor));
+ break;
+ }
+ case armnn::LayerType::StridedSlice:
+ {
+ CheckDescForNCHW(static_cast<const armnn::StridedSliceDescriptor&>(descriptor));
+ break;
+ }
+ default:
+ {
+ m_Result = false;
+ }
+ }
+ }
+
+ /// Returns true if the Layer had a DataLayout and it was NCHW or NCDHW.
+ /// Returns false if the Layer either doesn't have a DataLayout or if it
+ /// had a DataLayout that was neither NCHW nor NCDHW.
+ bool Result()
+ {
+ return m_Result;
+ }
+
+private:
+ template<typename Descriptor>
+ void CheckDescForNCHW(const Descriptor& descriptor)
+ {
+ m_Result = (descriptor.m_DataLayout == DataLayout::NCHW) || (descriptor.m_DataLayout == DataLayout::NCDHW);
+ }
+
+ bool m_Result = false;
+};
+
//
// this helper only works if all layers where the inputs connect to are not selected
//
@@ -49,6 +163,13 @@ SubgraphView::IOutputSlots CreateIOutputsFrom(const std::vector<armnn::IConnecta
}
+inline bool IsNCHW(armnn::Layer& layer)
+{
+ CheckForNCHW check;
+ layer.ExecuteStrategy(check);
+ return check.Result();
+}
+
inline void ReportUntouchedLayers(OptimizationViews& optimizationViews, std::map<LayerGuid, Layer*> untouched)
{
std::vector<Layer*> untouchedVector;
@@ -78,22 +199,69 @@ LayerType* FoldPadLayer(OptimizationViews& optimizationViews,
return replacementLayer;
}
+inline void RemoveReshapeLayer(ReshapeLayer* baseLayer,
+ std::map<LayerGuid, Layer*>& untouched,
+ OptimizationViews& optimizationViews)
+{
+ if (baseLayer == nullptr)
+ {
+ return;
+ }
+ ReshapeDescriptor reshapeDescriptor = baseLayer->GetParameters();
+ Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
+
+ // Cannot currently remove the Reshape if it's connected to an Input, Constant or Splitter
+ if (parentLayer.GetType() == LayerType::Input || parentLayer.GetType() == LayerType::Constant)
+ {
+ return;
+ }
+
+ // Cannot currently remove the Reshape if it's connected to an OutputSlot or Concat
+ for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
+ {
+ Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
+
+ if (nextLayer.GetType() == LayerType::Output)
+ {
+ return;
+ }
+ }
+ auto it = untouched.find(baseLayer->GetGuid());
+ if (it == untouched.end())
+ {
+ // Already removed from map
+ return;
+ }
+ untouched.erase(it);
+
+ // Override the InputSlot TensorInfos for all the layers connected to the Reshape's OutputSlot
+ for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i)
+ {
+ Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer();
+ auto inputIndex = baseLayer->GetOutputSlot(0).GetConnection(i)->GetSlotIndex();
+ TensorInfo reshapeInfo(baseLayer->GetOutputSlot(0).GetTensorInfo());
+ reshapeInfo.SetShape(reshapeDescriptor.m_TargetShape);
+ nextLayer.GetInputSlot(inputIndex).SetTensorInfo(reshapeInfo);
+ }
+ optimizationViews.AddDeletedSubgraph(baseLayer);
+}
+
template<typename LayerType>
LayerType* FoldPadIntoAveragePool2d(OptimizationViews& optimizationViews,
Pooling2dLayer* baseLayer,
Pooling2dDescriptor& poolDescriptor,
PadLayer* padLayer)
{
- IConnectableLayer* replacement =
- optimizationViews.GetINetwork()->AddPooling2dLayer(poolDescriptor, "folded-pad-into-pool2d");
- LayerType* replacementLayer = PolymorphicDowncast<LayerType*>(replacement);
+ IConnectableLayer* replacement =
+ optimizationViews.GetINetwork()->AddPooling2dLayer(poolDescriptor, "folded-pad-into-pool2d");
+ LayerType* replacementLayer = PolymorphicDowncast<LayerType*>(replacement);
- FoldPadLayer(optimizationViews,
- baseLayer,
- replacementLayer,
- padLayer);
+ FoldPadLayer(optimizationViews,
+ baseLayer,
+ replacementLayer,
+ padLayer);
- return replacementLayer;
+ return replacementLayer;
}
} // namespace armnn