aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/NeonWorkloadUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/NeonWorkloadUtils.cpp')
-rw-r--r--src/armnn/backends/NeonWorkloadUtils.cpp21
1 files changed, 20 insertions, 1 deletions
diff --git a/src/armnn/backends/NeonWorkloadUtils.cpp b/src/armnn/backends/NeonWorkloadUtils.cpp
index e807d23d6c..07e5d510eb 100644
--- a/src/armnn/backends/NeonWorkloadUtils.cpp
+++ b/src/armnn/backends/NeonWorkloadUtils.cpp
@@ -20,13 +20,14 @@
#include "NeonLayerSupport.hpp"
#include "../../../include/armnn/Types.hpp"
+#include "Half.hpp"
using namespace armnn::armcomputetensorutils;
namespace armnn
{
-// Allocate a tensor and copy the contents in data to the tensor contents
+// Allocates a tensor and copy the contents in data to the tensor contents.
template<typename T>
void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const T* data)
{
@@ -34,8 +35,26 @@ void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const T* data)
CopyArmComputeITensorData(data, tensor);
}
+template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const Half* data);
template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const float* data);
template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const uint8_t* data);
template void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const int32_t* data);
+void InitializeArmComputeTensorDataForFloatTypes(arm_compute::Tensor& tensor,
+ const ConstCpuTensorHandle* handle)
+{
+ BOOST_ASSERT(handle);
+ switch(handle->GetTensorInfo().GetDataType())
+ {
+ case DataType::Float16:
+ InitialiseArmComputeTensorData(tensor, handle->GetConstTensor<Half>());
+ break;
+ case DataType::Float32:
+ InitialiseArmComputeTensorData(tensor, handle->GetConstTensor<float>());
+ break;
+ default:
+ BOOST_ASSERT_MSG(false, "Unexpected floating point type.");
+ }
+};
+
} //namespace armnn