diff options
Diffstat (limited to 'src/backends/reference/workloads/RefComparisonWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefComparisonWorkload.cpp | 36 |
1 files changed, 27 insertions, 9 deletions
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp index 52ad9a2879..03df7a4c4a 100644 --- a/src/backends/reference/workloads/RefComparisonWorkload.cpp +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -26,9 +26,15 @@ RefComparisonWorkload::RefComparisonWorkload(const ComparisonQueueDescriptor& de void RefComparisonWorkload::PostAllocationConfigure() { - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + PostAllocationConfigure(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefComparisonWorkload::PostAllocationConfigure(std::vector<ITensorHandle*> inputs, + std::vector<ITensorHandle*> outputs) +{ + const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); m_Input0 = MakeDecoder<InType>(inputInfo0); m_Input1 = MakeDecoder<InType>(inputInfo1); @@ -38,19 +44,31 @@ void RefComparisonWorkload::PostAllocationConfigure() void RefComparisonWorkload::Execute() const { + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefComparisonWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) +{ + PostAllocationConfigure(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); + + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); +} + +void RefComparisonWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const +{ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefComparisonWorkload_Execute"); - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); const TensorShape& inShape0 = inputInfo0.GetShape(); const TensorShape& inShape1 = inputInfo1.GetShape(); const TensorShape& outShape = outputInfo.GetShape(); - m_Input0->Reset(m_Data.m_Inputs[0]->Map()); - m_Input1->Reset(m_Data.m_Inputs[1]->Map()); - m_Output->Reset(m_Data.m_Outputs[0]->Map()); + m_Input0->Reset(inputs[0]->Map()); + m_Input1->Reset(inputs[1]->Map()); + m_Output->Reset(outputs[0]->Map()); using EqualFunction = ElementwiseBinaryFunction<std::equal_to<InType>>; using GreaterFunction = ElementwiseBinaryFunction<std::greater<InType>>; |