aboutsummaryrefslogtreecommitdiff
path: root/src/backends/ClWorkloads/ClWorkloadUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/ClWorkloads/ClWorkloadUtils.hpp')
-rw-r--r--src/backends/ClWorkloads/ClWorkloadUtils.hpp12
1 files changed, 9 insertions, 3 deletions
diff --git a/src/backends/ClWorkloads/ClWorkloadUtils.hpp b/src/backends/ClWorkloads/ClWorkloadUtils.hpp
index 6f1b155745..a10237cf40 100644
--- a/src/backends/ClWorkloads/ClWorkloadUtils.hpp
+++ b/src/backends/ClWorkloads/ClWorkloadUtils.hpp
@@ -42,8 +42,8 @@ void InitialiseArmComputeClTensorData(arm_compute::CLTensor& clTensor, const T*
CopyArmComputeClTensorData<T>(data, clTensor);
}
-inline void InitializeArmComputeClTensorDataForFloatTypes(arm_compute::CLTensor& clTensor,
- const ConstCpuTensorHandle *handle)
+inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor,
+ const ConstCpuTensorHandle* handle)
{
BOOST_ASSERT(handle);
switch(handle->GetTensorInfo().GetDataType())
@@ -54,8 +54,14 @@ inline void InitializeArmComputeClTensorDataForFloatTypes(arm_compute::CLTensor&
case DataType::Float32:
InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<float>());
break;
+ case DataType::QuantisedAsymm8:
+ InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>());
+ break;
+ case DataType::Signed32:
+ InitialiseArmComputeClTensorData(clTensor, handle->GetConstTensor<int32_t>());
+ break;
default:
- BOOST_ASSERT_MSG(false, "Unexpected floating point type.");
+ BOOST_ASSERT_MSG(false, "Unexpected tensor type.");
}
};