diff options
Diffstat (limited to 'src/armnn/backends/NeonWorkloadUtils.cpp')
-rw-r--r-- | src/armnn/backends/NeonWorkloadUtils.cpp | 21 |
1 files changed, 20 insertions, 1 deletions
diff --git a/src/armnn/backends/NeonWorkloadUtils.cpp b/src/armnn/backends/NeonWorkloadUtils.cpp index e807d23d6c..07e5d510eb 100644 --- a/src/armnn/backends/NeonWorkloadUtils.cpp +++ b/src/armnn/backends/NeonWorkloadUtils.cpp @@ -20,13 +20,14 @@ #include "NeonLayerSupport.hpp" #include "../../../include/armnn/Types.hpp" +#include "Half.hpp" using namespace armnn::armcomputetensorutils; namespace armnn { -// Allocate a tensor and copy the contents in data to the tensor contents +// Allocates a tensor and copy the contents in data to the tensor contents. template<typename T> void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const T* data) { @@ -34,8 +35,26 @@ void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const T* data) CopyArmComputeITensorData(data, tensor); } +template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const Half* data); template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const float* data); template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const uint8_t* data); template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const int32_t* data); +void InitializeArmComputeTensorDataForFloatTypes(arm_compute::Tensor& tensor, + const ConstCpuTensorHandle* handle) +{ + BOOST_ASSERT(handle); + switch(handle->GetTensorInfo().GetDataType()) + { + case DataType::Float16: + InitialiseArmComputeTensorData(tensor, handle->GetConstTensor<Half>()); + break; + case DataType::Float32: + InitialiseArmComputeTensorData(tensor, handle->GetConstTensor<float>()); + break; + default: + BOOST_ASSERT_MSG(false, "Unexpected floating point type."); + } +}; + } //namespace armnn |