diff options
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 22a71c4923..365f1bdfa1 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1805,7 +1805,11 @@ IConnectableLayer* NetworkImpl::AddFullyConnectedLayer(const FullyConnectedDescr { weightsLayer = m_Graph->AddLayer<ConstantLayer>("Weights"); weightsLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(weights.value()); - weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsLayer->m_LayerOutput->GetTensorInfo()); + + TensorInfo weightsInfo = weightsLayer->m_LayerOutput->GetTensorInfo(); + weightsInfo.SetConstant(); + + weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo); } else if (fullyConnectedDescriptor.m_ConstantWeights) { @@ -1817,7 +1821,11 @@ IConnectableLayer* NetworkImpl::AddFullyConnectedLayer(const FullyConnectedDescr { biasLayer = m_Graph->AddLayer<ConstantLayer>("Biases"); biasLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(biases.value()); - biasLayer->GetOutputSlot(0).SetTensorInfo(biasLayer->m_LayerOutput->GetTensorInfo()); + + TensorInfo biasInfo = biasLayer->m_LayerOutput->GetTensorInfo(); + biasInfo.SetConstant(); + + biasLayer->GetOutputSlot(0).SetTensorInfo(biasInfo); } if (numInputs < 2) |