aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2021-04-23 15:20:07 +0100
committerSadik Armagan <sadik.armagan@arm.com>2021-04-28 09:00:47 +0000
commit2241d18f16878ddef261eadda9a0a8f0672a60c8 (patch)
tree6ec9cfc5dae3f4435c00564e1c8c24054dc7f98c
parent5d955cf70ae0c5558d4f431f0fc6bd4552cd43a5 (diff)
downloadarmnn-2241d18f16878ddef261eadda9a0a8f0672a60c8.tar.gz
IVGCVSW-5416 'Add android-nn-driver support for CAST'
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: I02da912e5e4ca650b367ca40fe3f5ca5baa61cbb
-rw-r--r--src/backends/reference/workloads/RefCastWorkload.cpp39
-rw-r--r--src/backends/reference/workloads/RefCastWorkload.hpp3
2 files changed, 34 insertions, 8 deletions
diff --git a/src/backends/reference/workloads/RefCastWorkload.cpp b/src/backends/reference/workloads/RefCastWorkload.cpp
index 7080415e5d..8f2a7259f1 100644
--- a/src/backends/reference/workloads/RefCastWorkload.cpp
+++ b/src/backends/reference/workloads/RefCastWorkload.cpp
@@ -26,15 +26,38 @@ namespace
namespace armnn
{
- void RefCastWorkload::Execute() const
- {
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefCastWorkload_Execute");
- const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
- const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+void RefCastWorkload::Execute() const
+{
+ Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+}
+
+void RefCastWorkload::ExecuteAsync(WorkingMemDescriptor &workingMemDescriptor)
+{
+ Execute(workingMemDescriptor.m_Inputs, workingMemDescriptor.m_Outputs);
+}
+
+void RefCastWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefCastWorkload_Execute");
+
+ TensorInfo inputTensorInfo(GetTensorInfo(inputs[0]));
+ TensorInfo outputTensorInfo(GetTensorInfo(outputs[0]));
- Cast(*MakeDecoder<float>(inputInfo, m_Data.m_Inputs[0]->Map()),
- *MakeEncoder<float>(outputInfo, m_Data.m_Outputs[0]->Map()),
- inputInfo.GetNumElements());
+ // Quantization info should be set to default values.
+ if (inputTensorInfo.IsQuantized())
+ {
+ inputTensorInfo.SetQuantizationScale(1.0f);
+ inputTensorInfo.SetQuantizationOffset(0);
}
+ if (outputTensorInfo.IsQuantized())
+ {
+ outputTensorInfo.SetQuantizationScale(1.0f);
+ outputTensorInfo.SetQuantizationOffset(0);
+ }
+
+ Cast(*MakeDecoder<float>(inputTensorInfo, inputs[0]->Map()),
+ *MakeEncoder<float>(outputTensorInfo, outputs[0]->Map()),
+ inputTensorInfo.GetNumElements());
+}
} //namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefCastWorkload.hpp b/src/backends/reference/workloads/RefCastWorkload.hpp
index 6742ef08ca..870fb410ac 100644
--- a/src/backends/reference/workloads/RefCastWorkload.hpp
+++ b/src/backends/reference/workloads/RefCastWorkload.hpp
@@ -18,6 +18,9 @@ class RefCastWorkload : public BaseWorkload<CastQueueDescriptor>
public:
using BaseWorkload<CastQueueDescriptor>::BaseWorkload;
void Execute() const override;
+ void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
+private:
+ void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;
};
} //namespace armnn