diff options
Diffstat (limited to 'src/backends/reference/workloads/RefWorkloadUtils.hpp')
-rw-r--r-- | src/backends/reference/workloads/RefWorkloadUtils.hpp | 64 |
1 files changed, 10 insertions, 54 deletions
diff --git a/src/backends/reference/workloads/RefWorkloadUtils.hpp b/src/backends/reference/workloads/RefWorkloadUtils.hpp index ce796160f2..c3260c8142 100644 --- a/src/backends/reference/workloads/RefWorkloadUtils.hpp +++ b/src/backends/reference/workloads/RefWorkloadUtils.hpp @@ -9,8 +9,10 @@ #include <armnn/Tensor.hpp> #include <armnn/Types.hpp> -#include <Half.hpp> +#include <reference/RefTensorHandle.hpp> + +#include <Half.hpp> #include <boost/polymorphic_cast.hpp> namespace armnn @@ -22,41 +24,24 @@ namespace armnn inline const TensorInfo& GetTensorInfo(const ITensorHandle* tensorHandle) { - // We know that reference workloads use CpuTensorHandles only, so this cast is legitimate. - const ConstCpuTensorHandle* cpuTensorHandle = - boost::polymorphic_downcast<const ConstCpuTensorHandle*>(tensorHandle); - return cpuTensorHandle->GetTensorInfo(); -} - -template <typename DataType> -inline const DataType* GetConstCpuData(const ITensorHandle* tensorHandle) -{ - // We know that reference workloads use (Const)CpuTensorHandles only, so this cast is legitimate. - const ConstCpuTensorHandle* cpuTensorHandle = - boost::polymorphic_downcast<const ConstCpuTensorHandle*>(tensorHandle); - return cpuTensorHandle->GetConstTensor<DataType>(); + // We know that reference workloads use RefTensorHandles for inputs and outputs + const RefTensorHandle* refTensorHandle = + boost::polymorphic_downcast<const RefTensorHandle*>(tensorHandle); + return refTensorHandle->GetTensorInfo(); } -template <typename DataType> -inline DataType* GetCpuData(const ITensorHandle* tensorHandle) -{ - // We know that reference workloads use CpuTensorHandles only, so this cast is legitimate. - const CpuTensorHandle* cpuTensorHandle = boost::polymorphic_downcast<const CpuTensorHandle*>(tensorHandle); - return cpuTensorHandle->GetTensor<DataType>(); -}; - template <typename DataType, typename PayloadType> const DataType* GetInputTensorData(unsigned int idx, const PayloadType& data) { const ITensorHandle* tensorHandle = data.m_Inputs[idx]; - return GetConstCpuData<DataType>(tensorHandle); + return reinterpret_cast<const DataType*>(tensorHandle->Map()); } template <typename DataType, typename PayloadType> DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data) { - const ITensorHandle* tensorHandle = data.m_Outputs[idx]; - return GetCpuData<DataType>(tensorHandle); + ITensorHandle* tensorHandle = data.m_Outputs[idx]; + return reinterpret_cast<DataType*>(tensorHandle->Map()); } template <typename PayloadType> @@ -87,35 +72,6 @@ Half* GetOutputTensorDataHalf(unsigned int idx, const PayloadType& data) /// u8 helpers //////////////////////////////////////////// -inline const uint8_t* GetConstCpuU8Data(const ITensorHandle* tensorHandle) -{ - // We know that reference workloads use (Const)CpuTensorHandles only, so this cast is legitimate. - const ConstCpuTensorHandle* cpuTensorHandle = - boost::polymorphic_downcast<const ConstCpuTensorHandle*>(tensorHandle); - return cpuTensorHandle->GetConstTensor<uint8_t>(); -}; - -inline uint8_t* GetCpuU8Data(const ITensorHandle* tensorHandle) -{ - // We know that reference workloads use CpuTensorHandles only, so this cast is legitimate. - const CpuTensorHandle* cpuTensorHandle = boost::polymorphic_downcast<const CpuTensorHandle*>(tensorHandle); - return cpuTensorHandle->GetTensor<uint8_t>(); -}; - -template <typename PayloadType> -const uint8_t* GetInputTensorDataU8(unsigned int idx, const PayloadType& data) -{ - const ITensorHandle* tensorHandle = data.m_Inputs[idx]; - return GetConstCpuU8Data(tensorHandle); -} - -template <typename PayloadType> -uint8_t* GetOutputTensorDataU8(unsigned int idx, const PayloadType& data) -{ - const ITensorHandle* tensorHandle = data.m_Outputs[idx]; - return GetCpuU8Data(tensorHandle); -} - template<typename T> std::vector<float> Dequantize(const T* quant, const TensorInfo& info) { |