diff options
Diffstat (limited to 'src/backends/backendsCommon/SubgraphUtils.hpp')
-rw-r--r-- | src/backends/backendsCommon/SubgraphUtils.hpp | 186 |
1 files changed, 177 insertions, 9 deletions
diff --git a/src/backends/backendsCommon/SubgraphUtils.hpp b/src/backends/backendsCommon/SubgraphUtils.hpp index bd3d698a98..ade4b63976 100644 --- a/src/backends/backendsCommon/SubgraphUtils.hpp +++ b/src/backends/backendsCommon/SubgraphUtils.hpp @@ -1,10 +1,12 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once +#include <armnn/StrategyBase.hpp> +#include <armnn/Descriptors.hpp> #include <optimizations/FoldPadIntoLayer2d.hpp> namespace armnn @@ -13,6 +15,118 @@ namespace armnn namespace { +/// Checks if a Layer has a DataLayout that is either NCHW or NCDHW. +class CheckForNCHW : public StrategyBase<NoThrowStrategy> +{ +public: + CheckForNCHW() + {} + + void ExecuteStrategy(const armnn::IConnectableLayer* layer, + const armnn::BaseDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name, + const armnn::LayerBindingId id = 0) override + { + armnn::IgnoreUnused(layer, constants, id, name); + switch (layer->GetType()) + { + case armnn::LayerType::BatchMatMul: + { + auto desc = static_cast<const armnn::BatchMatMulDescriptor &>(descriptor); + m_Result = desc.m_DataLayoutX == DataLayout::NCHW || desc.m_DataLayoutY == DataLayout::NCHW; + break; + } + case armnn::LayerType::BatchNormalization: + { + CheckDescForNCHW(static_cast<const armnn::BatchNormalizationDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::BatchToSpaceNd: + { + CheckDescForNCHW(static_cast<const armnn::BatchToSpaceNdDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::Convolution2d: + { + CheckDescForNCHW(static_cast<const armnn::Convolution2dDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::Convolution3d: + { + CheckDescForNCHW(static_cast<const armnn::Convolution3dDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::DepthwiseConvolution2d: + { + CheckDescForNCHW(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::InstanceNormalization: + { + CheckDescForNCHW(static_cast<const armnn::InstanceNormalizationDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::L2Normalization: + { + CheckDescForNCHW(static_cast<const armnn::L2NormalizationDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::Normalization: + { + CheckDescForNCHW(static_cast<const armnn::NormalizationDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::Pooling2d: + { + CheckDescForNCHW(static_cast<const armnn::Pooling2dDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::Pooling3d: + { + CheckDescForNCHW(static_cast<const armnn::Pooling3dDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::SpaceToBatchNd: + { + CheckDescForNCHW(static_cast<const armnn::SpaceToBatchNdDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::SpaceToDepth: + { + CheckDescForNCHW(static_cast<const armnn::SpaceToDepthDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::StridedSlice: + { + CheckDescForNCHW(static_cast<const armnn::StridedSliceDescriptor&>(descriptor)); + break; + } + default: + { + m_Result = false; + } + } + } + + /// Returns true if the Layer had a DataLayout and it was NCHW or NCDHW. + /// Returns false if the Layer either doesn't have a DataLayout or if it + /// had a DataLayout that was neither NCHW nor NCDHW. + bool Result() + { + return m_Result; + } + +private: + template<typename Descriptor> + void CheckDescForNCHW(const Descriptor& descriptor) + { + m_Result = (descriptor.m_DataLayout == DataLayout::NCHW) || (descriptor.m_DataLayout == DataLayout::NCDHW); + } + + bool m_Result = false; +}; + // // this helper only works if all layers where the inputs connect to are not selected // @@ -49,6 +163,13 @@ SubgraphView::IOutputSlots CreateIOutputsFrom(const std::vector<armnn::IConnecta } +inline bool IsNCHW(armnn::Layer& layer) +{ + CheckForNCHW check; + layer.ExecuteStrategy(check); + return check.Result(); +} + inline void ReportUntouchedLayers(OptimizationViews& optimizationViews, std::map<LayerGuid, Layer*> untouched) { std::vector<Layer*> untouchedVector; @@ -78,22 +199,69 @@ LayerType* FoldPadLayer(OptimizationViews& optimizationViews, return replacementLayer; } +inline void RemoveReshapeLayer(ReshapeLayer* baseLayer, + std::map<LayerGuid, Layer*>& untouched, + OptimizationViews& optimizationViews) +{ + if (baseLayer == nullptr) + { + return; + } + ReshapeDescriptor reshapeDescriptor = baseLayer->GetParameters(); + Layer& parentLayer = baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer(); + + // Cannot currently remove the Reshape if it's connected to an Input, Constant or Splitter + if (parentLayer.GetType() == LayerType::Input || parentLayer.GetType() == LayerType::Constant) + { + return; + } + + // Cannot currently remove the Reshape if it's connected to an OutputSlot or Concat + for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i) + { + Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer(); + + if (nextLayer.GetType() == LayerType::Output) + { + return; + } + } + auto it = untouched.find(baseLayer->GetGuid()); + if (it == untouched.end()) + { + // Already removed from map + return; + } + untouched.erase(it); + + // Override the InputSlot TensorInfos for all the layers connected to the Reshape's OutputSlot + for (unsigned int i = 0; i < baseLayer->GetOutputSlot(0).GetNumConnections(); ++i) + { + Layer& nextLayer = baseLayer->GetOutputSlot(0).GetConnection(i)->GetOwningLayer(); + auto inputIndex = baseLayer->GetOutputSlot(0).GetConnection(i)->GetSlotIndex(); + TensorInfo reshapeInfo(baseLayer->GetOutputSlot(0).GetTensorInfo()); + reshapeInfo.SetShape(reshapeDescriptor.m_TargetShape); + nextLayer.GetInputSlot(inputIndex).SetTensorInfo(reshapeInfo); + } + optimizationViews.AddDeletedSubgraph(baseLayer); +} + template<typename LayerType> LayerType* FoldPadIntoAveragePool2d(OptimizationViews& optimizationViews, Pooling2dLayer* baseLayer, Pooling2dDescriptor& poolDescriptor, PadLayer* padLayer) { - IConnectableLayer* replacement = - optimizationViews.GetINetwork()->AddPooling2dLayer(poolDescriptor, "folded-pad-into-pool2d"); - LayerType* replacementLayer = PolymorphicDowncast<LayerType*>(replacement); + IConnectableLayer* replacement = + optimizationViews.GetINetwork()->AddPooling2dLayer(poolDescriptor, "folded-pad-into-pool2d"); + LayerType* replacementLayer = PolymorphicDowncast<LayerType*>(replacement); - FoldPadLayer(optimizationViews, - baseLayer, - replacementLayer, - padLayer); + FoldPadLayer(optimizationViews, + baseLayer, + replacementLayer, + padLayer); - return replacementLayer; + return replacementLayer; } } // namespace armnn |