diff options
Diffstat (limited to 'src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp')
-rw-r--r-- | src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp b/src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp index 4b72d92d72..e0bc365053 100644 --- a/src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp +++ b/src/armnn/backends/ClWorkloads/ClBaseConstantWorkload.cpp @@ -4,17 +4,19 @@ // #include "ClBaseConstantWorkload.hpp" +#include "backends/ArmComputeTensorUtils.hpp" #include "backends/ClTensorHandle.hpp" #include "backends/CpuTensorHandle.hpp" +#include "Half.hpp" namespace armnn { -template class ClBaseConstantWorkload<DataType::Float32>; +template class ClBaseConstantWorkload<DataType::Float16, DataType::Float32>; template class ClBaseConstantWorkload<DataType::QuantisedAsymm8>; -template<armnn::DataType dataType> -void ClBaseConstantWorkload<dataType>::Execute() const +template<armnn::DataType... dataTypes> +void ClBaseConstantWorkload<dataTypes...>::Execute() const { // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data // on the first inference, then reused for subsequent inferences. @@ -26,15 +28,21 @@ void ClBaseConstantWorkload<dataType>::Execute() const BOOST_ASSERT(data.m_LayerOutput != nullptr); arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor(); + arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType(); - switch (dataType) + switch (computeDataType) { - case DataType::Float32: + case arm_compute::DataType::F16: + { + CopyArmComputeClTensorData(data.m_LayerOutput->GetConstTensor<Half>(), output); + break; + } + case arm_compute::DataType::F32: { CopyArmComputeClTensorData(data.m_LayerOutput->GetConstTensor<float>(), output); break; } - case DataType::QuantisedAsymm8: + case arm_compute::DataType::QASYMM8: { CopyArmComputeClTensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output); break; |