diff options
Diffstat (limited to 'src/backends/ClWorkloads/ClWorkloadUtils.hpp')
-rw-r--r-- | src/backends/ClWorkloads/ClWorkloadUtils.hpp | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/src/backends/ClWorkloads/ClWorkloadUtils.hpp b/src/backends/ClWorkloads/ClWorkloadUtils.hpp index 6f1b155745..a10237cf40 100644 --- a/src/backends/ClWorkloads/ClWorkloadUtils.hpp +++ b/src/backends/ClWorkloads/ClWorkloadUtils.hpp @@ -42,8 +42,8 @@ void InitialiseArmComputeClTensorData(arm_compute::CLTensor& clTensor, const T* CopyArmComputeClTensorData<T>(data, clTensor); } -inline void InitializeArmComputeClTensorDataForFloatTypes(arm_compute::CLTensor& clTensor, - const ConstCpuTensorHandle *handle) +inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor, + const ConstCpuTensorHandle* handle) { BOOST_ASSERT(handle); switch(handle->GetTensorInfo().GetDataType()) @@ -54,8 +54,14 @@ inline void InitializeArmComputeClTensorDataForFloatTypes(arm_compute::CLTensor& case DataType::Float32: InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<float>()); break; + case DataType::QuantisedAsymm8: + InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>()); + break; + case DataType::Signed32: + InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<int32_t>()); + break; default: - BOOST_ASSERT_MSG(false, "Unexpected floating point type."); + BOOST_ASSERT_MSG(false, "Unexpected tensor type."); } }; |