diff options
Diffstat (limited to 'src/backends/reference/workloads/RefFullyConnectedWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefFullyConnectedWorkload.cpp | 41 |
1 files changed, 29 insertions, 12 deletions
diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp index 49e105f206..deb56d4c6b 100644 --- a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp +++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp @@ -34,28 +34,32 @@ RefFullyConnectedWorkload::RefFullyConnectedWorkload( void RefFullyConnectedWorkload::PostAllocationConfigure() { - const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + PostAllocationConfigure(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefFullyConnectedWorkload::PostAllocationConfigure(std::vector<ITensorHandle*> inputs, + std::vector<ITensorHandle*> outputs) +{ + const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1); m_InputShape = inputInfo.GetShape(); - m_InputDecoder = MakeDecoder<float>(inputInfo); if (!m_Data.m_Parameters.m_ConstantWeights) { - const TensorInfo& rWeightInfo = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& rWeightInfo = GetTensorInfo(inputs[1]); ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1); m_WeightShape = rWeightInfo.GetShape(); m_WeightDecoder = MakeDecoder<float>(rWeightInfo); if (m_Data.m_Parameters.m_BiasEnabled) { - const TensorInfo& biasInfo = GetTensorInfo(m_Data.m_Inputs[2]); + const TensorInfo& biasInfo = GetTensorInfo(inputs[2]); m_BiasDecoder = MakeDecoder<float>(biasInfo); } } - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); m_OutputShape = outputInfo.GetShape(); - m_OutputEncoder = MakeEncoder<float>(outputInfo); m_NumActivations = 1; // Total number of activations in the input. for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++) @@ -66,23 +70,36 @@ void RefFullyConnectedWorkload::PostAllocationConfigure() void RefFullyConnectedWorkload::Execute() const { + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefFullyConnectedWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) +{ + PostAllocationConfigure(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); + + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); +} + +void RefFullyConnectedWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const +{ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute"); - m_InputDecoder->Reset(m_Data.m_Inputs[0]->Map()); + std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), inputs[0]->Map()); + std::unique_ptr<Encoder<float>> OutputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), outputs[0]->Map()); + if (!m_Data.m_Parameters.m_ConstantWeights) { - m_WeightDecoder->Reset(m_Data.m_Inputs[1]->Map()); + m_WeightDecoder->Reset(inputs[1]->Map()); if (m_Data.m_Parameters.m_BiasEnabled) { - m_BiasDecoder->Reset(m_Data.m_Inputs[2]->Map()); + m_BiasDecoder->Reset(inputs[2]->Map()); } } - m_OutputEncoder->Reset(m_Data.m_Outputs[0]->Map()); FullyConnected(m_InputShape, - *m_InputDecoder, + *inputDecoder, m_OutputShape, - *m_OutputEncoder, + *OutputEncoder, m_WeightShape, *m_WeightDecoder, *m_BiasDecoder, |