diff options
Diffstat (limited to 'src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp')
-rw-r--r-- | src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp b/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp
index e432a6b833..7395270400 100644
--- a/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp
+++ b/src/backends/neon/workloads/NeonFullyConnectedWorkload.cpp
@@ -5,10 +5,13 @@
 
 #include "NeonFullyConnectedWorkload.hpp"
+#include "NeonWorkloadUtils.hpp"
 
 #include <aclCommon/ArmComputeTensorUtils.hpp>
 #include <aclCommon/ArmComputeUtils.hpp>
 #include <backendsCommon/CpuTensorHandle.hpp>
 
+#include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h>
+
 namespace armnn
 {
 using namespace armcomputetensorutils;
@@ -45,7 +48,6 @@ arm_compute::Status NeonFullyConnectedWorkloadValidate(const TensorInfo& input,
 
 NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor,
     const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
     : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
-    , m_FullyConnectedLayer(memoryManager)
 {
     m_Data.ValidateInputsOutputs("NeonFullyConnectedWorkload", 1, 1);
@@ -64,7 +66,10 @@ NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueue
     // Construct
     arm_compute::FullyConnectedLayerInfo fc_info;
     fc_info.transpose_weights = m_Data.m_Parameters.m_TransposeWeightMatrix;
-    m_FullyConnectedLayer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
+
+    auto layer = std::make_unique<arm_compute::NEFullyConnectedLayer>(memoryManager);
+    layer->configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
+    m_FullyConnectedLayer.reset(layer.release());
 
     // Allocate
     if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QuantisedAsymm8)
@@ -90,14 +95,14 @@ NeonFullyConnectedWorkload::NeonFullyConnectedWorkload(const FullyConnectedQueue
 
     // Force Compute Library to perform the necessary copying and reshaping, after which
     // delete all the input tensors that will no longer be needed
-    m_FullyConnectedLayer.prepare();
+    m_FullyConnectedLayer->prepare();
     FreeUnusedTensors();
 }
 
 void NeonFullyConnectedWorkload::Execute() const
 {
     ARMNN_SCOPED_PROFILING_EVENT_NEON("NeonFullyConnectedWorkload_Execute");
-    m_FullyConnectedLayer.run();
+    m_FullyConnectedLayer->run();
 }
 
 void NeonFullyConnectedWorkload::FreeUnusedTensors()