diff options
Diffstat (limited to 'src/armnn/layers/DepthwiseConvolution2dLayer.cpp')
-rw-r--r-- | src/armnn/layers/DepthwiseConvolution2dLayer.cpp | 51 |
1 files changed, 14 insertions, 37 deletions
diff --git a/src/armnn/layers/DepthwiseConvolution2dLayer.cpp b/src/armnn/layers/DepthwiseConvolution2dLayer.cpp index b23661b4a8..08f6fafa1b 100644 --- a/src/armnn/layers/DepthwiseConvolution2dLayer.cpp +++ b/src/armnn/layers/DepthwiseConvolution2dLayer.cpp @@ -22,7 +22,7 @@ namespace armnn DepthwiseConvolution2dLayer::DepthwiseConvolution2dLayer(const DepthwiseConvolution2dDescriptor& param, const char* name) - : LayerWithParameters(1, 1, LayerType::DepthwiseConvolution2d, param, name) + : LayerWithParameters(param.GetNumInputs(), 1, LayerType::DepthwiseConvolution2d, param, name) { } @@ -31,10 +31,9 @@ void DepthwiseConvolution2dLayer::SerializeLayerParameters(ParameterStringifyFun const std::vector<TensorShape>& inputShapes = { GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), - m_Weight->GetTensorInfo().GetShape() + GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() }; const TensorShape filterShape = inputShapes[1]; - DataLayoutIndexed dataLayoutIndex(m_Param.m_DataLayout); unsigned int inputChannels = filterShape[1]; unsigned int filterWidth = filterShape[3]; unsigned int filterHeight = filterShape[2]; @@ -50,16 +49,14 @@ void DepthwiseConvolution2dLayer::SerializeLayerParameters(ParameterStringifyFun std::unique_ptr<IWorkload> DepthwiseConvolution2dLayer::CreateWorkload(const IWorkloadFactory& factory) const { - // on this level constant data should not be released.. - ARMNN_ASSERT_MSG(m_Weight != nullptr, "DepthwiseConvolution2dLayer: Weights data should not be null."); - DepthwiseConvolution2dQueueDescriptor descriptor; - descriptor.m_Weight = m_Weight.get(); - - if (m_Param.m_BiasEnabled) + if (m_Weight) + { + descriptor.m_Weight = m_Weight.get(); + } + if (m_Param.m_BiasEnabled && m_Bias) { - ARMNN_ASSERT_MSG(m_Bias != nullptr, "DepthwiseConvolution2dLayer: Bias data should not be null."); descriptor.m_Bias = m_Bias.get(); } @@ -124,19 +121,19 @@ DepthwiseConvolution2dLayer::InferOutputShapes(const std::vector<TensorShape>& i void DepthwiseConvolution2dLayer::ValidateTensorShapesFromInputs() { - VerifyLayerConnections(1, CHECK_LOCATION()); + VerifyLayerConnections(m_Param.GetNumInputs(), CHECK_LOCATION()); const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); - // on this level constant data should not be released.. - ARMNN_ASSERT_MSG(m_Weight != nullptr, "DepthwiseConvolution2dLayer: Weights data should not be null."); + ARMNN_ASSERT_MSG(GetInputSlot(1).GetConnection(), + "DepthwiseConvolution2dLayer: Weights data should not be null."); auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(), - m_Weight->GetTensorInfo().GetShape() - }); + GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() + }); ARMNN_ASSERT(inferredShapes.size() == 1); @@ -152,33 +149,13 @@ Layer::ConstantTensors DepthwiseConvolution2dLayer::GetConstantTensorsByRef() ARMNN_NO_DEPRECATE_WARN_BEGIN void DepthwiseConvolution2dLayer::Accept(ILayerVisitor& visitor) const { - ManagedConstTensorHandle managedWeight(m_Weight); - ConstTensor weightsTensor(managedWeight.GetTensorInfo(), managedWeight.Map()); - Optional<ConstTensor> optionalBiasTensor = EmptyOptional(); - - ManagedConstTensorHandle managedBias(m_Bias); - if (GetParameters().m_BiasEnabled) - { - ConstTensor biasTensor(managedBias.GetTensorInfo(), managedBias.Map()); - optionalBiasTensor = Optional<ConstTensor>(biasTensor); - } - - visitor.VisitDepthwiseConvolution2dLayer(this, GetParameters(), weightsTensor, optionalBiasTensor, GetName()); + visitor.VisitDepthwiseConvolution2dLayer(this, GetParameters(), GetName()); } ARMNN_NO_DEPRECATE_WARN_END void DepthwiseConvolution2dLayer::ExecuteStrategy(IStrategy& strategy) const { - ManagedConstTensorHandle managedWeight(m_Weight); - std::vector<armnn::ConstTensor> constTensors { { managedWeight.GetTensorInfo(), managedWeight.Map() } }; - - ManagedConstTensorHandle managedBias(m_Bias); - if (GetParameters().m_BiasEnabled) - { - constTensors.emplace_back(ConstTensor(managedBias.GetTensorInfo(), managedBias.Map(true))); - } - - strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName()); + strategy.ExecuteStrategy(this, GetParameters(), {}, GetName()); } } // namespace armnn |