aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/FullyConnectedLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/FullyConnectedLayer.cpp')
-rw-r--r--src/armnn/layers/FullyConnectedLayer.cpp79
1 files changed, 11 insertions, 68 deletions
diff --git a/src/armnn/layers/FullyConnectedLayer.cpp b/src/armnn/layers/FullyConnectedLayer.cpp
index 9d4f57d260..8dfb011730 100644
--- a/src/armnn/layers/FullyConnectedLayer.cpp
+++ b/src/armnn/layers/FullyConnectedLayer.cpp
@@ -15,24 +15,20 @@ namespace armnn
{
FullyConnectedLayer::FullyConnectedLayer(const FullyConnectedDescriptor& param, const char* name)
- : LayerWithParameters(param.GetNumViews(), 1, LayerType::FullyConnected, param, name)
+ : LayerWithParameters(param.GetNumInputs(), 1, LayerType::FullyConnected, param, name)
{
}
std::unique_ptr<IWorkload> FullyConnectedLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
- // on this level constant data should not be released..
FullyConnectedQueueDescriptor descriptor;
- if (m_Param.m_ConstantWeights)
+ if (m_Weight)
{
- ARMNN_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null.");
descriptor.m_Weight = m_Weight.get();
-
- if (m_Param.m_BiasEnabled)
- {
- ARMNN_ASSERT_MSG(m_Bias != nullptr, "FullyConnectedLayer: Bias data should not be null.");
- descriptor.m_Bias = m_Bias.get();
- }
+ }
+ if (m_Param.m_BiasEnabled && m_Bias)
+ {
+ descriptor.m_Bias = m_Bias.get();
}
SetAdditionalInfo(descriptor);
@@ -42,15 +38,6 @@ std::unique_ptr<IWorkload> FullyConnectedLayer::CreateWorkload(const IWorkloadFa
FullyConnectedLayer* FullyConnectedLayer::Clone(Graph& graph) const
{
auto layer = CloneBase<FullyConnectedLayer>(graph, m_Param, GetName());
- if (m_Param.m_ConstantWeights)
- {
- layer->m_Weight = m_Weight ? m_Weight : nullptr;
-
- if (layer->m_Param.m_BiasEnabled)
- {
- layer->m_Bias = m_Bias ? m_Bias : nullptr;
- }
- }
return std::move(layer);
}
@@ -73,20 +60,9 @@ void FullyConnectedLayer::ValidateTensorShapesFromInputs()
VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
- std::vector<TensorShape> inferredShapes;
- if (m_Param.m_ConstantWeights)
- {
- // check if m_Weight data is not nullptr
- ARMNN_ASSERT_MSG(m_Weight != nullptr, "FullyConnectedLayer: Weights data should not be null.");
-
- inferredShapes = InferOutputShapes({GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
- m_Weight->GetTensorInfo().GetShape()});
- }
- else
- {
- inferredShapes = InferOutputShapes({GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
- GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()});
- }
+ std::vector<TensorShape> inferredShapes = InferOutputShapes(
+ {GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()});
ARMNN_ASSERT(inferredShapes.size() == 1);
ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified);
@@ -101,45 +77,12 @@ Layer::ConstantTensors FullyConnectedLayer::GetConstantTensorsByRef()
void FullyConnectedLayer::Accept(ILayerVisitor& visitor) const
{
- Optional<ConstTensor> optionalWeightsTensor = EmptyOptional();
- Optional<ConstTensor> optionalBiasTensor = EmptyOptional();
-
- ManagedConstTensorHandle managedWeight(m_Weight);
- ManagedConstTensorHandle managedBias(m_Bias);
- if (GetParameters().m_ConstantWeights)
- {
- ConstTensor weightsTensor(managedWeight.GetTensorInfo(), managedWeight.Map());
- optionalWeightsTensor = Optional<ConstTensor>(weightsTensor);
-
- if (GetParameters().m_BiasEnabled)
- {
- ConstTensor biasTensor(managedBias.GetTensorInfo(), managedBias.Map());
- optionalBiasTensor = Optional<ConstTensor>(biasTensor);
- }
- }
-
- visitor.VisitFullyConnectedLayer(this,
- GetParameters(),
- optionalWeightsTensor.value(),
- optionalBiasTensor,
- GetName());
+ visitor.VisitFullyConnectedLayer(this, GetParameters(), GetName());
}
void FullyConnectedLayer::ExecuteStrategy(IStrategy& strategy) const
{
- std::vector <armnn::ConstTensor> constTensors;
- ManagedConstTensorHandle managedWeight(m_Weight);
- ManagedConstTensorHandle managedBias(m_Bias);
-
- if(GetParameters().m_ConstantWeights)
- {
- constTensors.emplace_back(ConstTensor(managedWeight.GetTensorInfo(), managedWeight.Map()));
- if (GetParameters().m_BiasEnabled)
- {
- constTensors.emplace_back(ConstTensor(managedBias.GetTensorInfo(), managedBias.Map()));
- }
- }
- strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
+ strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
}
} // namespace armnn