diff options
Diffstat (limited to 'src/backends/reference/workloads/RefRankWorkload.hpp')
-rw-r--r-- | src/backends/reference/workloads/RefRankWorkload.hpp | 15 |
1 files changed, 13 insertions, 2 deletions
diff --git a/src/backends/reference/workloads/RefRankWorkload.hpp b/src/backends/reference/workloads/RefRankWorkload.hpp index 660db6b8db..237ae999ce 100644 --- a/src/backends/reference/workloads/RefRankWorkload.hpp +++ b/src/backends/reference/workloads/RefRankWorkload.hpp @@ -19,10 +19,21 @@ public: using BaseWorkload<RankQueueDescriptor>::BaseWorkload; virtual void Execute() const override { - const int32_t rank = static_cast<int32_t>(GetTensorInfo(m_Data.m_Inputs[0]).GetNumDimensions()); + Execute(m_Data.m_Inputs, m_Data.m_Outputs); + + } + void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override + { + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); + } + +private: + void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const + { + const int32_t rank = static_cast<int32_t>(GetTensorInfo(inputs[0]).GetNumDimensions()); std::memcpy(GetOutputTensorData<void>(0, m_Data), &rank, sizeof(int32_t)); - m_Data.m_Outputs[0]->Unmap(); + outputs[0]->Unmap(); } }; |