aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/Convolution2dLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/Convolution2dLayer.cpp')
-rw-r--r--src/armnn/layers/Convolution2dLayer.cpp51
1 files changed, 15 insertions, 36 deletions
diff --git a/src/armnn/layers/Convolution2dLayer.cpp b/src/armnn/layers/Convolution2dLayer.cpp
index ef5db8e9b9..7b3382bf93 100644
--- a/src/armnn/layers/Convolution2dLayer.cpp
+++ b/src/armnn/layers/Convolution2dLayer.cpp
@@ -21,7 +21,7 @@ namespace armnn
{
Convolution2dLayer::Convolution2dLayer(const Convolution2dDescriptor& param, const char* name)
- : LayerWithParameters(1, 1, LayerType::Convolution2d, param, name)
+ : LayerWithParameters(param.GetNumInputs(), 1, LayerType::Convolution2d, param, name)
{
}
@@ -32,7 +32,7 @@ void Convolution2dLayer::SerializeLayerParameters(ParameterStringifyFunction& fn
const std::vector<TensorShape>& inputShapes =
{
GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
- m_Weight->GetTensorInfo().GetShape()
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()
};
const TensorShape filterShape = inputShapes[1];
DataLayoutIndexed dataLayoutIndex(m_Param.m_DataLayout);
@@ -49,15 +49,14 @@ void Convolution2dLayer::SerializeLayerParameters(ParameterStringifyFunction& fn
std::unique_ptr<IWorkload> Convolution2dLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
// on this level constant data should not be released..
- ARMNN_ASSERT_MSG(m_Weight != nullptr, "Convolution2dLayer: Weights data should not be null.");
ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Convolution2dLayer_CreateWorkload");
Convolution2dQueueDescriptor descriptor;
-
- descriptor.m_Weight = m_Weight.get();
-
- if (m_Param.m_BiasEnabled)
+ if (m_Weight)
+ {
+ descriptor.m_Weight = m_Weight.get();
+ }
+ if (m_Param.m_BiasEnabled && m_Bias)
{
- ARMNN_ASSERT_MSG(m_Bias != nullptr, "Convolution2dLayer: Bias data should not be null.");
descriptor.m_Bias = m_Bias.get();
}
@@ -120,18 +119,18 @@ std::vector<TensorShape> Convolution2dLayer::InferOutputShapes(const std::vector
void Convolution2dLayer::ValidateTensorShapesFromInputs()
{
- VerifyLayerConnections(1, CHECK_LOCATION());
+ VerifyLayerConnections(m_Param.GetNumInputs(), CHECK_LOCATION());
const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
- // check if we m_Weight data is not nullptr
- ARMNN_ASSERT_MSG(m_Weight != nullptr, "Convolution2dLayer: Weights data should not be null.");
+ ARMNN_ASSERT_MSG(GetInputSlot(1).GetConnection(),
+ "Convolution2dLayer: Weights should be connected to input slot 1.");
- auto inferredShapes = InferOutputShapes({
- GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
- m_Weight->GetTensorInfo().GetShape() });
+ std::vector<TensorShape> inferredShapes = InferOutputShapes({
+ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
ARMNN_ASSERT(inferredShapes.size() == 1);
@@ -147,33 +146,13 @@ Layer::ConstantTensors Convolution2dLayer::GetConstantTensorsByRef()
ARMNN_NO_DEPRECATE_WARN_BEGIN
void Convolution2dLayer::Accept(ILayerVisitor& visitor) const
{
- ManagedConstTensorHandle managedWeight(m_Weight);
- ConstTensor weightsTensor(managedWeight.GetTensorInfo(), managedWeight.Map());
-
- Optional<ConstTensor> optionalBiasTensor = EmptyOptional();
- ManagedConstTensorHandle managedBias(m_Bias);
- if (GetParameters().m_BiasEnabled)
- {
- ConstTensor biasTensor(managedBias.GetTensorInfo(), managedBias.Map());
- optionalBiasTensor = Optional<ConstTensor>(biasTensor);
- }
-
- visitor.VisitConvolution2dLayer(this, GetParameters(), weightsTensor, optionalBiasTensor, GetName());
+ visitor.VisitConvolution2dLayer(this, GetParameters(), GetName());
}
ARMNN_NO_DEPRECATE_WARN_END
void Convolution2dLayer::ExecuteStrategy(IStrategy& strategy) const
{
- ManagedConstTensorHandle managedWeight(m_Weight);
- std::vector<armnn::ConstTensor> constTensors { { managedWeight.GetTensorInfo(), managedWeight.Map() } };
-
- ManagedConstTensorHandle managedBias(m_Bias);
- if (GetParameters().m_BiasEnabled)
- {
- constTensors.emplace_back(ConstTensor(managedBias.GetTensorInfo(), managedBias.Map()));
- }
-
- strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
+ strategy.ExecuteStrategy(this, GetParameters(), { }, GetName());
}
} // namespace armnn