Diffstat (limited to 'src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp')
-rw-r--r--  src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp  33
1 file changed, 22 insertions, 11 deletions
diff --git a/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp b/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp
index 0b91eb37c2..7bb23f870b 100644
--- a/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp
+++ b/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
@@ -70,15 +70,15 @@ NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueue
     // Copy the weights' tensor into arm_compute tensor.
     m_WeightsTensor = std::make_unique<arm_compute::Tensor>();
-    BuildArmComputeTensor(*m_WeightsTensor, m_Data.m_Weight->GetTensorInfo());
-    InitializeArmComputeTensorData(*m_WeightsTensor, m_Data.m_Weight);
-
+    m_WeightsTensorInfo = info.m_InputTensorInfos[1];
+    BuildArmComputeTensor(*m_WeightsTensor, m_WeightsTensorInfo);
+
     if (m_Data.m_Parameters.m_BiasEnabled)
     {
         // Copy the biases tensor into arm_compute tensor.
         m_BiasesTensor = std::make_unique<arm_compute::Tensor>();
-        BuildArmComputeTensor(*m_BiasesTensor, m_Data.m_Bias->GetTensorInfo());
-        InitializeArmComputeTensorData(*m_BiasesTensor, m_Data.m_Bias);
+        m_BiasesTensorInfo = info.m_InputTensorInfos[2];
+        BuildArmComputeTensor(*m_BiasesTensor, m_BiasesTensorInfo);
     }

     const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
@@ -94,10 +94,10 @@ NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueue
     detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
     detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos;
-    detailsInfo.m_WeightsTensorInfo = armnn::Optional<armnn::TensorInfo>(descriptor.m_Weight->GetTensorInfo());
+    detailsInfo.m_WeightsTensorInfo = armnn::Optional<armnn::TensorInfo>(info.m_InputTensorInfos[1]);
     if (descriptor.m_Parameters.m_BiasEnabled)
     {
-        detailsInfo.m_BiasTensorInfo = armnn::Optional<armnn::TensorInfo>(descriptor.m_Bias->GetTensorInfo());
+        detailsInfo.m_BiasTensorInfo = armnn::Optional<armnn::TensorInfo>(info.m_InputTensorInfos[2]);
     }

     // Report Profiling Details
@@ -107,14 +107,25 @@ NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueue
                                          this->GetGuid());

     // Force Compute Library to perform the necessary copying and reshaping.
-    m_FullyConnectedLayer->prepare();
-    FreeTensorIfUnused(m_WeightsTensor);
-    FreeTensorIfUnused(m_BiasesTensor);
 }

 void NeonFullyConnectedWorkload::Execute() const
 {
     ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonFullyConnectedWorkload_Execute", this->GetGuid());
+    // The constant tensors may not be fully in place until the workload is Executed
+    if (!prepared)
+    {
+        InitializeArmComputeTensorData(*m_WeightsTensor, m_WeightsTensorInfo, m_Data.m_Inputs[1]);
+
+        if (m_Data.m_Parameters.m_BiasEnabled)
+        {
+            InitializeArmComputeTensorData(*m_BiasesTensor, m_BiasesTensorInfo, m_Data.m_Inputs[2]);
+        }
+        m_FullyConnectedLayer->prepare();
+        FreeTensorIfUnused(m_WeightsTensor);
+        FreeTensorIfUnused(m_BiasesTensor);
+        prepared = true;
+    }
     m_FullyConnectedLayer->run();
 }
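
The members this change relies on (m_WeightsTensorInfo, m_BiasesTensorInfo, and the prepared flag) live in the companion header, which is not part of this diff. Below is a minimal sketch of the private section NeonFullyConnectedWorkload.hpp would need, assuming the names used above; the key detail is that Execute() is declared const, so the tensors it frees and the one-shot prepared flag must be mutable:

// Sketch only: header-side declarations implied by this diff. The real
// NeonFullyConnectedWorkload.hpp is not shown here, so anything beyond the
// names used in the .cpp above is an assumption.
private:
    std::unique_ptr<arm_compute::IFunction> m_FullyConnectedLayer;

    // Freed inside the const Execute() via FreeTensorIfUnused, hence mutable.
    mutable std::unique_ptr<arm_compute::Tensor> m_WeightsTensor;
    mutable std::unique_ptr<arm_compute::Tensor> m_BiasesTensor;

    // Captured from info.m_InputTensorInfos in the constructor so the
    // constant data can be copied in on the first Execute() call.
    TensorInfo m_WeightsTensorInfo;
    TensorInfo m_BiasesTensorInfo;

    // One-shot guard for the deferred prepare; written from the const Execute().
    mutable bool prepared = false;

Deferring InitializeArmComputeTensorData from the constructor to the first Execute() means the weight and bias data are read from the workload's inputs (m_Data.m_Inputs[1] and m_Data.m_Inputs[2]) rather than from the m_Weight/m_Bias members, so they need not be populated until the workload actually runs; the prepared flag ensures the copy, prepare(), and FreeTensorIfUnused() sequence happens exactly once.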