diff options
author | Finn Williams <Finn.Williams@arm.com> | 2021-04-07 10:23:21 +0100 |
---|---|---|
committer | Finn Williams <Finn.Williams@arm.com> | 2021-04-14 15:18:38 +0100 |
commit | b8181f72b8c7c9132373dbcf7f8709ec2c0f23c0 (patch) | |
tree | 04cc91a6efb7e2601f80e4213a747938165b7184 /src/backends/reference/workloads/RefFullyConnectedWorkload.cpp | |
parent | b898222a8856475f0217be5e78b4816aa1914f15 (diff) | |
download | armnn-b8181f72b8c7c9132373dbcf7f8709ec2c0f23c0.tar.gz |
IVGCVSW-5787 Add/Update Execute() implementations in RefActivationWorkload
* Added multithreaded StridedSliceEndToEndTest
Signed-off-by: Finn Williams <Finn.Williams@arm.com>
Change-Id: I4579db7b5959e0a22256f1bda00238c22e611dec
Diffstat (limited to 'src/backends/reference/workloads/RefFullyConnectedWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefFullyConnectedWorkload.cpp | 41 |
1 files changed, 29 insertions, 12 deletions
diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp index 49e105f206..deb56d4c6b 100644 --- a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp +++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp @@ -34,28 +34,32 @@ RefFullyConnectedWorkload::RefFullyConnectedWorkload( void RefFullyConnectedWorkload::PostAllocationConfigure() { - const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); + PostAllocationConfigure(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefFullyConnectedWorkload::PostAllocationConfigure(std::vector<ITensorHandle*> inputs, + std::vector<ITensorHandle*> outputs) +{ + const TensorInfo& inputInfo = GetTensorInfo(inputs[0]); ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1); m_InputShape = inputInfo.GetShape(); - m_InputDecoder = MakeDecoder<float>(inputInfo); if (!m_Data.m_Parameters.m_ConstantWeights) { - const TensorInfo& rWeightInfo = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& rWeightInfo = GetTensorInfo(inputs[1]); ARMNN_ASSERT(inputInfo.GetNumDimensions() > 1); m_WeightShape = rWeightInfo.GetShape(); m_WeightDecoder = MakeDecoder<float>(rWeightInfo); if (m_Data.m_Parameters.m_BiasEnabled) { - const TensorInfo& biasInfo = GetTensorInfo(m_Data.m_Inputs[2]); + const TensorInfo& biasInfo = GetTensorInfo(inputs[2]); m_BiasDecoder = MakeDecoder<float>(biasInfo); } } - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); m_OutputShape = outputInfo.GetShape(); - m_OutputEncoder = MakeEncoder<float>(outputInfo); m_NumActivations = 1; // Total number of activations in the input. for (unsigned int i = 1; i < inputInfo.GetNumDimensions(); i++) @@ -66,23 +70,36 @@ void RefFullyConnectedWorkload::PostAllocationConfigure() void RefFullyConnectedWorkload::Execute() const { + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefFullyConnectedWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) +{ + PostAllocationConfigure(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); + + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); +} + +void RefFullyConnectedWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const +{ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefFullyConnectedWorkload_Execute"); - m_InputDecoder->Reset(m_Data.m_Inputs[0]->Map()); + std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), inputs[0]->Map()); + std::unique_ptr<Encoder<float>> OutputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), outputs[0]->Map()); + if (!m_Data.m_Parameters.m_ConstantWeights) { - m_WeightDecoder->Reset(m_Data.m_Inputs[1]->Map()); + m_WeightDecoder->Reset(inputs[1]->Map()); if (m_Data.m_Parameters.m_BiasEnabled) { - m_BiasDecoder->Reset(m_Data.m_Inputs[2]->Map()); + m_BiasDecoder->Reset(inputs[2]->Map()); } } - m_OutputEncoder->Reset(m_Data.m_Outputs[0]->Map()); FullyConnected(m_InputShape, - *m_InputDecoder, + *inputDecoder, m_OutputShape, - *m_OutputEncoder, + *OutputEncoder, m_WeightShape, *m_WeightDecoder, *m_BiasDecoder, |