aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp')
-rw-r--r--src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp13
1 files changed, 9 insertions, 4 deletions
diff --git a/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp b/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp
index e432a6b833..7395270400 100644
--- a/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp
+++ b/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp
@@ -5,10 +5,13 @@
#include "NeonFullyConnectedWorkload.hpp"
+#include "NeonWorkloadUtils.hpp"
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <aclCommon/ArmComputeUtils.hpp>
#include <backendsCommon/CpuTensorHandle.hpp>
+#include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h>
+
namespace armnn
{
using namespace armcomputetensorutils;
@@ -45,7 +48,6 @@ arm_compute::Status NeonFullyConnectedWorkloadValidate(const TensorInfo& input,
NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor,
const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
: BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
- , m_FullyConnectedLayer(memoryManager)
{
m_Data.ValidateInputsOutputs("NeonFullyConnectedWorkload", 1, 1);
@@ -64,7 +66,10 @@ NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueue
// Construct
arm_compute::FullyConnectedLayerInfo fc_info;
fc_info.transpose_weights = m_Data.m_Parameters.m_TransposeWeightMatrix;
- m_FullyConnectedLayer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
+
+ auto layer = std::make_unique<arm_compute::NEFullyConnectedLayer>(memoryManager);
+ layer->configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
+ m_FullyConnectedLayer.reset(layer.release());
// Allocate
if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QuantisedAsymm8)
@@ -90,14 +95,14 @@ NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueue
// Force Compute Library to perform the necessary copying and reshaping, after which
// delete all the input tensors that will no longer be needed
- m_FullyConnectedLayer.prepare();
+ m_FullyConnectedLayer->prepare();
FreeUnusedTensors();
}
void NeonFullyConnectedWorkload::Execute() const
{
ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonFullyConnectedWorkload_Execute");
- m_FullyConnectedLayer.run();
+ m_FullyConnectedLayer->run();
}
void NeonFullyConnectedWorkload::FreeUnusedTensors()