aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonWorkloadUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/workloads/NeonWorkloadUtils.hpp')
-rw-r--r--src/backends/neon/workloads/NeonWorkloadUtils.hpp58
1 files changed, 41 insertions, 17 deletions
diff --git a/src/backends/neon/workloads/NeonWorkloadUtils.hpp b/src/backends/neon/workloads/NeonWorkloadUtils.hpp
index c4accd6c89..48ec753546 100644
--- a/src/backends/neon/workloads/NeonWorkloadUtils.hpp
+++ b/src/backends/neon/workloads/NeonWorkloadUtils.hpp
@@ -5,30 +5,54 @@
#pragma once
#include <backends/Workload.hpp>
-
+#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
#include <backends/neon/NeonTensorHandle.hpp>
#include <backends/neon/NeonTimer.hpp>
-
-#include <arm_compute/core/Types.h>
-#include <arm_compute/core/Helpers.h>
+#include <backends/CpuTensorHandle.hpp>
#include <arm_compute/runtime/NEON/NEFunctions.h>
-#include <arm_compute/runtime/SubTensor.h>
-#include <boost/cast.hpp>
+#include <Half.hpp>
+
+#define ARMNN_SCOPED_PROFILING_EVENT_NEON(name) \
+ ARMNN_SCOPED_PROFILING_EVENT_WITH_INSTRUMENTS(armnn::Compute::CpuAcc, \
+ name, \
+ armnn::NeonTimer(), \
+ armnn::WallClockTimer())
+
+using namespace armnn::armcomputetensorutils;
namespace armnn
{
-class Layer;
-
-template<typename T>
-void InitialiseArmComputeTensorData(arm_compute::Tensor& tensor, const T* data);
-void InitializeArmComputeTensorDataForFloatTypes(arm_compute::Tensor& tensor, const ConstCpuTensorHandle* handle);
-} //namespace armnn
+template <typename T>
+void CopyArmComputeTensorData(arm_compute::Tensor& dstTensor, const T* srcData)
+{
+ InitialiseArmComputeTensorEmpty(dstTensor);
+ CopyArmComputeITensorData(srcData, dstTensor);
+}
+inline void InitializeArmComputeTensorData(arm_compute::Tensor& tensor,
+ const ConstCpuTensorHandle* handle)
+{
+ BOOST_ASSERT(handle);
+
+ switch(handle->GetTensorInfo().GetDataType())
+ {
+ case DataType::Float16:
+ CopyArmComputeTensorData(tensor, handle->GetConstTensor<armnn::Half>());
+ break;
+ case DataType::Float32:
+ CopyArmComputeTensorData(tensor, handle->GetConstTensor<float>());
+ break;
+ case DataType::QuantisedAsymm8:
+ CopyArmComputeTensorData(tensor, handle->GetConstTensor<uint8_t>());
+ break;
+ case DataType::Signed32:
+ CopyArmComputeTensorData(tensor, handle->GetConstTensor<int32_t>());
+ break;
+ default:
+ BOOST_ASSERT_MSG(false, "Unexpected tensor type.");
+ }
+};
-#define ARMNN_SCOPED_PROFILING_EVENT_NEON(name) \
- ARMNN_SCOPED_PROFILING_EVENT_WITH_INSTRUMENTS(armnn::Compute::CpuAcc, \
- name, \
- armnn::NeonTimer(), \
- armnn::WallClockTimer())
+} //namespace armnn