diff options
Diffstat (limited to 'src/backends/reference/workloads/RefConvolution3dWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefConvolution3dWorkload.cpp | 33 |
1 files changed, 24 insertions, 9 deletions
diff --git a/src/backends/reference/workloads/RefConvolution3dWorkload.cpp b/src/backends/reference/workloads/RefConvolution3dWorkload.cpp index ea425daec9..afab88f0a8 100644 --- a/src/backends/reference/workloads/RefConvolution3dWorkload.cpp +++ b/src/backends/reference/workloads/RefConvolution3dWorkload.cpp @@ -19,10 +19,10 @@ RefConvolution3dWorkload::RefConvolution3dWorkload( WorkloadInfo detailsInfo; detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos; detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos; - detailsInfo.m_WeightsTensorInfo = armnn::Optional<armnn::TensorInfo>(descriptor.m_Weight->GetTensorInfo()); + detailsInfo.m_WeightsTensorInfo = armnn::Optional<armnn::TensorInfo>(info.m_InputTensorInfos[1]); if (descriptor.m_Parameters.m_BiasEnabled) { - detailsInfo.m_BiasTensorInfo = armnn::Optional<armnn::TensorInfo>(descriptor.m_Bias->GetTensorInfo()); + detailsInfo.m_BiasTensorInfo = armnn::Optional<armnn::TensorInfo>(info.m_InputTensorInfos[2]); } // Report Profiling Details @@ -30,18 +30,25 @@ RefConvolution3dWorkload::RefConvolution3dWorkload( descriptor.m_Parameters, detailsInfo, this->GetGuid()); +} - m_Weight = std::make_unique<ScopedTensorHandle>(*( descriptor.m_Weight )); - const TensorInfo& rFilterInfo = m_Weight->GetTensorInfo(); +void RefConvolution3dWorkload::PostAllocationConfigure() +{ + PostAllocationConfigure(m_Data.m_Inputs, m_Data.m_Outputs); +} +void RefConvolution3dWorkload::PostAllocationConfigure(std::vector<ITensorHandle*> inputs, + std::vector<ITensorHandle*> outputs) +{ + IgnoreUnused(outputs); + const TensorInfo& rFilterInfo = GetTensorInfo(inputs[1]); m_FilterShape = rFilterInfo.GetShape(); - m_FilterDecoder = MakeDecoder<float>(rFilterInfo, m_Weight.get()->Map(true)); + m_FilterDecoder = MakeDecoder<float>(rFilterInfo); - if ( descriptor.m_Parameters.m_BiasEnabled ) + if (m_Data.m_Parameters.m_BiasEnabled) { - m_Bias = std::make_unique<ScopedTensorHandle>(*( descriptor.m_Bias )); - const TensorInfo& biasInfo = m_Bias->GetTensorInfo(); - m_BiasDecoder = MakeDecoder<float>(biasInfo, m_Bias->Map(true)); + const TensorInfo& biasInfo = GetTensorInfo(inputs[2]); + m_BiasDecoder = MakeDecoder<float>(biasInfo); } } @@ -52,6 +59,8 @@ void RefConvolution3dWorkload::Execute() const void RefConvolution3dWorkload::ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) { + PostAllocationConfigure(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); } @@ -65,6 +74,12 @@ void RefConvolution3dWorkload::Execute(std::vector<ITensorHandle*> inputs, std:: const TensorShape& inputShape = GetTensorInfo(inputs[0]).GetShape(); const TensorShape& outputShape = GetTensorInfo(outputs[0]).GetShape(); + m_FilterDecoder->Reset(inputs[1]->Map()); + if (m_Data.m_Parameters.m_BiasEnabled) + { + m_BiasDecoder->Reset(inputs[2]->Map()); + } + Convolve3d(inputShape, *inputDecoder, outputShape, *outputEncoder, m_FilterShape, *m_FilterDecoder, m_Data.m_Parameters.m_BiasEnabled, m_BiasDecoder.get(), m_Data.m_Parameters.m_DataLayout, |