diff options
Diffstat (limited to 'src/armnn/optimizations/FoldPadIntoLayer2d.hpp')
-rw-r--r-- | src/armnn/optimizations/FoldPadIntoLayer2d.hpp | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/src/armnn/optimizations/FoldPadIntoLayer2d.hpp b/src/armnn/optimizations/FoldPadIntoLayer2d.hpp index 637f2b36d3..73188883b2 100644 --- a/src/armnn/optimizations/FoldPadIntoLayer2d.hpp +++ b/src/armnn/optimizations/FoldPadIntoLayer2d.hpp @@ -58,6 +58,13 @@ inline bool IsNeutralElement(const Convolution2dDescriptor&, const TensorInfo& t return tensorValue == GetZeroElement(tensorInfo); } +inline bool IsNeutralElement(const DepthwiseConvolution2dDescriptor&, + const TensorInfo& tensorInfo, + const float tensorValue) +{ + return tensorValue == GetZeroElement(tensorInfo); +} + inline bool IsNeutralElement( const Pooling2dDescriptor& descriptor, const TensorInfo& tensorInfo, const float tensorValue) { @@ -179,6 +186,35 @@ protected: ~FoldPadIntoConvolution2dImpl() = default; }; +class FoldPadIntoDepthwiseConvolution2dImpl +{ +public: + void Run(Graph& graph, InputSlot& connection) const + { + const auto newConv2dLayer = FoldPadIntoLayer2dImpl<DepthwiseConvolution2dLayer>(graph, connection); + + if (newConv2dLayer != nullptr) + { + const auto conv2dLayer = PolymorphicDowncast<DepthwiseConvolution2dLayer*>(&connection.GetOwningLayer()); + // Copy weights and bias to the new convolution layer + ARMNN_ASSERT_MSG(conv2dLayer->m_Weight != nullptr, + "FoldPadIntoDepthwiseConvolution2d: Weights data should not be null."); + newConv2dLayer->m_Weight = std::move(conv2dLayer->m_Weight); + + if (conv2dLayer->GetParameters().m_BiasEnabled) + { + ARMNN_ASSERT_MSG(conv2dLayer->m_Bias != nullptr, + "FoldPadIntoDepthwiseConvolution2d: Bias data should not be null if bias is enabled."); + newConv2dLayer->m_Bias = std::move(conv2dLayer->m_Bias); + } + } + } + +protected: + FoldPadIntoDepthwiseConvolution2dImpl() = default; + ~FoldPadIntoDepthwiseConvolution2dImpl() = default; +}; + class FoldPadIntoPooling2dImpl { public: @@ -195,6 +231,10 @@ protected: using FoldPadIntoConvolution2d = OptimizeForExclusiveConnection<PadLayer, Convolution2dLayer, pad_fold::FoldPadIntoConvolution2dImpl>; +using FoldPadIntoDepthwiseConvolution2d = + OptimizeForExclusiveConnection <PadLayer, + DepthwiseConvolution2dLayer, + pad_fold::FoldPadIntoDepthwiseConvolution2dImpl>; using FoldPadIntoPooling2d = OptimizeForExclusiveConnection<PadLayer, Pooling2dLayer, pad_fold::FoldPadIntoPooling2dImpl>; |