aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp')
-rw-r--r--src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp20
1 files changed, 14 insertions, 6 deletions
diff --git a/src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp b/src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp
index 4b72d92d72..e0bc365053 100644
--- a/src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp
+++ b/src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp
@@ -4,17 +4,19 @@
//
#include "ClBaseConstantWorkload.hpp"
+#include "backends/ArmComputeTensorUtils.hpp"
#include "backends/ClTensorHandle.hpp"
#include "backends/CpuTensorHandle.hpp"
+#include "Half.hpp"
namespace armnn
{
-template class ClBaseConstantWorkload<DataType::Float32>;
+template class ClBaseConstantWorkload<DataType::Float16, DataType::Float32>;
template class ClBaseConstantWorkload<DataType::QuantisedAsymm8>;
-template<armnn::DataType dataType>
-void ClBaseConstantWorkload<dataType>::Execute() const
+template<armnn::DataType... dataTypes>
+void ClBaseConstantWorkload<dataTypes...>::Execute() const
{
// The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
// on the first inference, then reused for subsequent inferences.
@@ -26,15 +28,21 @@ void ClBaseConstantWorkload<dataType>::Execute() const
BOOST_ASSERT(data.m_LayerOutput != nullptr);
arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor();
+ arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType();
- switch (dataType)
+ switch (computeDataType)
{
- case DataType::Float32:
+ case arm_compute::DataType::F16:
+ {
+ CopyArmComputeClTensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
+ break;
+ }
+ case arm_compute::DataType::F32:
{
CopyArmComputeClTensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
break;
}
- case DataType::QuantisedAsymm8:
+ case arm_compute::DataType::QASYMM8:
{
CopyArmComputeClTensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
break;