aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/DepthwiseConvolution2dLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/DepthwiseConvolution2dLayer.cpp')
-rw-r--r--src/armnn/layers/DepthwiseConvolution2dLayer.cpp51
1 files changed, 14 insertions, 37 deletions
diff --git a/src/armnn/layers/DepthwiseConvolution2dLayer.cpp b/src/armnn/layers/DepthwiseConvolution2dLayer.cpp
index b23661b4a8..08f6fafa1b 100644
--- a/src/armnn/layers/DepthwiseConvolution2dLayer.cpp
+++ b/src/armnn/layers/DepthwiseConvolution2dLayer.cpp
@@ -22,7 +22,7 @@ namespace armnn
DepthwiseConvolution2dLayer::DepthwiseConvolution2dLayer(const DepthwiseConvolution2dDescriptor& param,
const char* name)
- : LayerWithParameters(1, 1, LayerType::DepthwiseConvolution2d, param, name)
+ : LayerWithParameters(param.GetNumInputs(), 1, LayerType::DepthwiseConvolution2d, param, name)
{
}
@@ -31,10 +31,9 @@ void DepthwiseConvolution2dLayer::SerializeLayerParameters(ParameterStringifyFun
const std::vector<TensorShape>& inputShapes =
{
GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
- m_Weight->GetTensorInfo().GetShape()
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()
};
const TensorShape filterShape = inputShapes[1];
- DataLayoutIndexed dataLayoutIndex(m_Param.m_DataLayout);
unsigned int inputChannels = filterShape[1];
unsigned int filterWidth = filterShape[3];
unsigned int filterHeight = filterShape[2];
@@ -50,16 +49,14 @@ void DepthwiseConvolution2dLayer::SerializeLayerParameters(ParameterStringifyFun
std::unique_ptr<IWorkload> DepthwiseConvolution2dLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
- // on this level constant data should not be released..
- ARMNN_ASSERT_MSG(m_Weight != nullptr, "DepthwiseConvolution2dLayer: Weights data should not be null.");
-
DepthwiseConvolution2dQueueDescriptor descriptor;
- descriptor.m_Weight = m_Weight.get();
-
- if (m_Param.m_BiasEnabled)
+ if (m_Weight)
+ {
+ descriptor.m_Weight = m_Weight.get();
+ }
+ if (m_Param.m_BiasEnabled && m_Bias)
{
- ARMNN_ASSERT_MSG(m_Bias != nullptr, "DepthwiseConvolution2dLayer: Bias data should not be null.");
descriptor.m_Bias = m_Bias.get();
}
@@ -124,19 +121,19 @@ DepthwiseConvolution2dLayer::InferOutputShapes(const std::vector<TensorShape>& i
void DepthwiseConvolution2dLayer::ValidateTensorShapesFromInputs()
{
- VerifyLayerConnections(1, CHECK_LOCATION());
+ VerifyLayerConnections(m_Param.GetNumInputs(), CHECK_LOCATION());
const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
- // on this level constant data should not be released..
- ARMNN_ASSERT_MSG(m_Weight != nullptr, "DepthwiseConvolution2dLayer: Weights data should not be null.");
+ ARMNN_ASSERT_MSG(GetInputSlot(1).GetConnection(),
+ "DepthwiseConvolution2dLayer: Weights data should not be null.");
auto inferredShapes = InferOutputShapes({
GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
- m_Weight->GetTensorInfo().GetShape()
- });
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()
+ });
ARMNN_ASSERT(inferredShapes.size() == 1);
@@ -152,33 +149,13 @@ Layer::ConstantTensors DepthwiseConvolution2dLayer::GetConstantTensorsByRef()
ARMNN_NO_DEPRECATE_WARN_BEGIN
void DepthwiseConvolution2dLayer::Accept(ILayerVisitor& visitor) const
{
- ManagedConstTensorHandle managedWeight(m_Weight);
- ConstTensor weightsTensor(managedWeight.GetTensorInfo(), managedWeight.Map());
- Optional<ConstTensor> optionalBiasTensor = EmptyOptional();
-
- ManagedConstTensorHandle managedBias(m_Bias);
- if (GetParameters().m_BiasEnabled)
- {
- ConstTensor biasTensor(managedBias.GetTensorInfo(), managedBias.Map());
- optionalBiasTensor = Optional<ConstTensor>(biasTensor);
- }
-
- visitor.VisitDepthwiseConvolution2dLayer(this, GetParameters(), weightsTensor, optionalBiasTensor, GetName());
+ visitor.VisitDepthwiseConvolution2dLayer(this, GetParameters(), GetName());
}
ARMNN_NO_DEPRECATE_WARN_END
void DepthwiseConvolution2dLayer::ExecuteStrategy(IStrategy& strategy) const
{
- ManagedConstTensorHandle managedWeight(m_Weight);
- std::vector<armnn::ConstTensor> constTensors { { managedWeight.GetTensorInfo(), managedWeight.Map() } };
-
- ManagedConstTensorHandle managedBias(m_Bias);
- if (GetParameters().m_BiasEnabled)
- {
- constTensors.emplace_back(ConstTensor(managedBias.GetTensorInfo(), managedBias.Map(true)));
- }
-
- strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
+ strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
}
} // namespace armnn