diff options
Diffstat (limited to 'src/backends/cl/workloads')
-rw-r--r-- | src/backends/cl/workloads/ClDequantizeWorkload.cpp | 2 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClQuantizeWorkload.cpp | 2 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloadUtils.hpp | 1 |
3 files changed, 5 insertions, 0 deletions
diff --git a/src/backends/cl/workloads/ClDequantizeWorkload.cpp b/src/backends/cl/workloads/ClDequantizeWorkload.cpp index 67a555a020..eca795de7e 100644 --- a/src/backends/cl/workloads/ClDequantizeWorkload.cpp +++ b/src/backends/cl/workloads/ClDequantizeWorkload.cpp @@ -32,6 +32,8 @@ ClDequantizeWorkload::ClDequantizeWorkload(const DequantizeQueueDescriptor& desc const WorkloadInfo& workloadInfo) : BaseWorkload<DequantizeQueueDescriptor>(descriptor, workloadInfo) { + m_Data.ValidateInputsOutputs("ClDequantizeWorkload", 1, 1); + arm_compute::ICLTensor& input = boost::polymorphic_pointer_downcast<IClTensorHandle>( m_Data.m_Inputs[0])->GetTensor(); diff --git a/src/backends/cl/workloads/ClQuantizeWorkload.cpp b/src/backends/cl/workloads/ClQuantizeWorkload.cpp index 230e346a00..263065a5a4 100644 --- a/src/backends/cl/workloads/ClQuantizeWorkload.cpp +++ b/src/backends/cl/workloads/ClQuantizeWorkload.cpp @@ -32,6 +32,8 @@ arm_compute::Status ClQuantizeWorkloadValidate(const TensorInfo& input, ClQuantizeWorkload::ClQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo& info) : BaseWorkload<QuantizeQueueDescriptor>(descriptor, info) { + m_Data.ValidateInputsOutputs("ClQuantizeWorkload", 1, 1); + arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); diff --git a/src/backends/cl/workloads/ClWorkloadUtils.hpp b/src/backends/cl/workloads/ClWorkloadUtils.hpp index d3c6df50ed..b4bcc1c017 100644 --- a/src/backends/cl/workloads/ClWorkloadUtils.hpp +++ b/src/backends/cl/workloads/ClWorkloadUtils.hpp @@ -101,6 +101,7 @@ inline void InitializeArmComputeClTensorData(arm_compute::CLTensor& clTensor, case DataType::Float32: CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<float>()); break; + case DataType::QAsymmS8: case DataType::QAsymmU8: CopyArmComputeClTensorData(clTensor, handle->GetConstTensor<uint8_t>()); break; |