From 2241d18f16878ddef261eadda9a0a8f0672a60c8 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Fri, 23 Apr 2021 15:20:07 +0100 Subject: IVGCVSW-5416 'Add android-nn-driver support for CAST' Signed-off-by: Sadik Armagan Change-Id: I02da912e5e4ca650b367ca40fe3f5ca5baa61cbb --- .../reference/workloads/RefCastWorkload.cpp | 39 +++++++++++++++++----- .../reference/workloads/RefCastWorkload.hpp | 3 ++ 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/backends/reference/workloads/RefCastWorkload.cpp b/src/backends/reference/workloads/RefCastWorkload.cpp index 7080415e5d..8f2a7259f1 100644 --- a/src/backends/reference/workloads/RefCastWorkload.cpp +++ b/src/backends/reference/workloads/RefCastWorkload.cpp @@ -26,15 +26,38 @@ namespace namespace armnn { - void RefCastWorkload::Execute() const - { - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefCastWorkload_Execute"); - const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); +void RefCastWorkload::Execute() const +{ + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefCastWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) +{ + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); +} + +void RefCastWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefCastWorkload_Execute"); + + TensorInfo inputTensorInfo(GetTensorInfo(inputs[0])); + TensorInfo outputTensorInfo(GetTensorInfo(outputs[0])); - Cast(*MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map()), - *MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()), - inputInfo.GetNumElements()); + // Quantization info should be set to default values.
+ if (inputTensorInfo.IsQuantized()) + { + inputTensorInfo.SetQuantizationScale(1.0f); + inputTensorInfo.SetQuantizationOffset(0); } + if (outputTensorInfo.IsQuantized()) + { + outputTensorInfo.SetQuantizationScale(1.0f); + outputTensorInfo.SetQuantizationOffset(0); + } + + Cast(*MakeDecoder<float>(inputTensorInfo, inputs[0]->Map()), + *MakeEncoder<float>(outputTensorInfo, outputs[0]->Map()), + inputTensorInfo.GetNumElements()); +} } //namespace armnn \ No newline at end of file diff --git a/src/backends/reference/workloads/RefCastWorkload.hpp b/src/backends/reference/workloads/RefCastWorkload.hpp index 6742ef08ca..870fb410ac 100644 --- a/src/backends/reference/workloads/RefCastWorkload.hpp +++ b/src/backends/reference/workloads/RefCastWorkload.hpp @@ -18,6 +18,9 @@ class RefCastWorkload : public BaseWorkload<CastQueueDescriptor> public: using BaseWorkload<CastQueueDescriptor>::BaseWorkload; void Execute() const override; + void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; +private: + void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; }; } //namespace armnn -- cgit v1.2.1