diff options
Diffstat (limited to 'src/backends/reference/workloads/RefElementwiseWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefElementwiseWorkload.cpp | 36 |
1 files changed, 19 insertions, 17 deletions
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 60acbd6252..dd7d325ca5 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -26,39 +26,41 @@ RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::RefElementwiseWo } template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> -void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::PostAllocationConfigure() +void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const { - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} - m_Input0 = MakeDecoder<InType>(inputInfo0); - m_Input1 = MakeDecoder<InType>(inputInfo1); - m_Output = MakeEncoder<OutType>(outputInfo); +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::ExecuteAsync( + WorkingMemDescriptor &workingMemDescriptor) +{ + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); } template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> -void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const +void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute( + std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const { ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString)); - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + const TensorInfo& inputInfo0 = GetTensorInfo(inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); const TensorShape& inShape0 = inputInfo0.GetShape(); const TensorShape& inShape1 = inputInfo1.GetShape(); const TensorShape& outShape = outputInfo.GetShape(); - m_Input0->Reset(m_Data.m_Inputs[0]->Map()); - m_Input1->Reset(m_Data.m_Inputs[1]->Map()); - m_Output->Reset(m_Data.m_Outputs[0]->Map()); + std::unique_ptr<Decoder<InType>> input0 = MakeDecoder<InType>(inputInfo0, inputs[0]->Map()); + std::unique_ptr<Decoder<InType>> input1 = MakeDecoder<InType>(inputInfo1, inputs[1]->Map()); + std::unique_ptr<Encoder<OutType>> output= MakeEncoder<OutType>(outputInfo, outputs[0]->Map()); ElementwiseBinaryFunction<Functor>(inShape0, inShape1, outShape, - *m_Input0, - *m_Input1, - *m_Output); + *input0, + *input1, + *output); } } //namespace armnn |