diff options
Diffstat (limited to 'src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp')
-rw-r--r-- | src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp b/src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp index 688ebf9184..636bdecbeb 100644 --- a/src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp +++ b/src/backends/cl/workloads/ClQuantizedLstmWorkload.cpp @@ -62,7 +62,8 @@ arm_compute::Status ClQuantizedLstmWorkloadValidate(const TensorInfo& input, con } ClQuantizedLstmWorkload::ClQuantizedLstmWorkload(const QuantizedLstmQueueDescriptor &descriptor, - const WorkloadInfo &info): + const WorkloadInfo &info, + const arm_compute::CLCompileContext& clCompileContext): BaseWorkload<QuantizedLstmQueueDescriptor>(descriptor, info) { m_InputToInputWeightsTensor = std::make_unique<arm_compute::CLTensor>(); @@ -108,7 +109,8 @@ ClQuantizedLstmWorkload::ClQuantizedLstmWorkload(const QuantizedLstmQueueDescrip arm_compute::ICLTensor& cellStateOutTensor = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); arm_compute::ICLTensor& outputStateOutTensor = static_cast<IClTensorHandle*>(m_Data.m_Outputs[1])->GetTensor(); - m_QuantizedLstmLayer.configure(&inputTensor, m_InputToInputWeightsTensor.get(), m_InputToForgetWeightsTensor.get(), + m_QuantizedLstmLayer.configure(clCompileContext, &inputTensor, m_InputToInputWeightsTensor.get(), + m_InputToForgetWeightsTensor.get(), m_InputToCellWeightsTensor.get(), m_InputToOutputWeightsTensor.get(), m_RecurrentToInputWeightsTensor.get(), m_RecurrentToForgetWeightsTensor.get(), m_RecurrentToCellWeightsTensor.get(), m_RecurrentToOutputWeightsTensor.get(), |