aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/FullyConnectedLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/FullyConnectedLayer.cpp')
-rw-r--r--src/armnn/layers/FullyConnectedLayer.cpp17
1 files changed, 12 insertions, 5 deletions
diff --git a/src/armnn/layers/FullyConnectedLayer.cpp b/src/armnn/layers/FullyConnectedLayer.cpp
index 44c8920136..79d56c0bd7 100644
--- a/src/armnn/layers/FullyConnectedLayer.cpp
+++ b/src/armnn/layers/FullyConnectedLayer.cpp
@@ -103,17 +103,21 @@ void FullyConnectedLayer::Accept(ILayerVisitor& visitor) const
{
Optional<ConstTensor> optionalWeightsTensor = EmptyOptional();
Optional<ConstTensor> optionalBiasTensor = EmptyOptional();
- if(GetParameters().m_ConstantWeights)
+
+ ManagedConstTensorHandle managedWeight(m_Weight);
+ ManagedConstTensorHandle managedBias(m_Bias);
+ if (GetParameters().m_ConstantWeights)
{
- ConstTensor weightsTensor(m_Weight->GetTensorInfo(), m_Weight->GetConstTensor<void>());
+ ConstTensor weightsTensor(managedWeight.GetTensorInfo(), managedWeight.Map());
optionalWeightsTensor = Optional<ConstTensor>(weightsTensor);
if (GetParameters().m_BiasEnabled)
{
- ConstTensor biasTensor(m_Bias->GetTensorInfo(), m_Bias->GetConstTensor<void>());
+ ConstTensor biasTensor(managedBias.GetTensorInfo(), managedBias.Map());
optionalBiasTensor = Optional<ConstTensor>(biasTensor);
}
}
+
visitor.VisitFullyConnectedLayer(this,
GetParameters(),
optionalWeightsTensor.value(),
@@ -124,12 +128,15 @@ void FullyConnectedLayer::Accept(ILayerVisitor& visitor) const
void FullyConnectedLayer::ExecuteStrategy(IStrategy& strategy) const
{
std::vector <armnn::ConstTensor> constTensors;
+ ManagedConstTensorHandle managedWeight(m_Weight);
+ ManagedConstTensorHandle managedBias(m_Bias);
+
if(GetParameters().m_ConstantWeights)
{
- constTensors.emplace_back(ConstTensor(m_Weight->GetTensorInfo(), m_Weight->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedWeight.GetTensorInfo(), managedWeight.Map()));
if (GetParameters().m_BiasEnabled)
{
- constTensors.emplace_back(ConstTensor(m_Bias->GetTensorInfo(), m_Bias->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedBias.GetTensorInfo(), managedBias.Map()));
}
}
strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());