diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/workloads/RefCastWorkload.cpp | 39 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefCastWorkload.hpp | 3 |
2 files changed, 34 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/RefCastWorkload.cpp b/src/backends/reference/workloads/RefCastWorkload.cpp index 7080415e5d..8f2a7259f1 100644 --- a/src/backends/reference/workloads/RefCastWorkload.cpp +++ b/src/backends/reference/workloads/RefCastWorkload.cpp @@ -26,15 +26,38 @@ namespace namespace armnn { - void RefCastWorkload::Execute() const - { - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefCastWorkload_Execute"); - const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); +void RefCastWorkload::Execute() const +{ + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefCastWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor) +{ + Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs); +} + +void RefCastWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefCastWorkload_Execute"); + + TensorInfo inputTensorInfo(GetTensorInfo(inputs[0])); + TensorInfo outputTensorInfo(GetTensorInfo(outputs[0])); - Cast(*MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map()), - *MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()), - inputInfo.GetNumElements()); + // Quantization info should set to default values. + if (inputTensorInfo.IsQuantized()) + { + inputTensorInfo.SetQuantizationScale(1.0f); + inputTensorInfo.SetQuantizationOffset(0); } + if (outputTensorInfo.IsQuantized()) + { + outputTensorInfo.SetQuantizationScale(1.0f); + outputTensorInfo.SetQuantizationOffset(0); + } + + Cast(*MakeDecoder<float>(inputTensorInfo, inputs[0]->Map()), + *MakeEncoder<float>(outputTensorInfo, outputs[0]->Map()), + inputTensorInfo.GetNumElements()); +} } //namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefCastWorkload.hpp b/src/backends/reference/workloads/RefCastWorkload.hpp index 6742ef08ca..870fb410ac 100644 --- a/src/backends/reference/workloads/RefCastWorkload.hpp +++ b/src/backends/reference/workloads/RefCastWorkload.hpp @@ -18,6 +18,9 @@ class RefCastWorkload : public BaseWorkload<CastQueueDescriptor> public: using BaseWorkload<CastQueueDescriptor>::BaseWorkload; void Execute() const override; + void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; +private: + void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; }; } //namespace armnn |