aboutsummaryrefslogtreecommitdiff
path: root/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/ClWorkloads/ClLstmFloatWorkload.cpp')
-rw-r--r--src/backends/ClWorkloads/ClLstmFloatWorkload.cpp51
1 files changed, 17 insertions, 34 deletions
diff --git a/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp b/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
index 09a34c2d02..8e2c875bab 100644
--- a/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
+++ b/src/backends/ClWorkloads/ClLstmFloatWorkload.cpp
@@ -172,57 +172,40 @@ ClLstmFloatWorkload::ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor,
armcomputetensorutils::InitialiseArmComputeTensorEmpty(*m_ScratchBuffer);
- InitialiseArmComputeClTensorData(*m_InputToForgetWeightsTensor,
- m_Data.m_InputToForgetWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_InputToCellWeightsTensor,
- m_Data.m_InputToCellWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_InputToOutputWeightsTensor,
- m_Data.m_InputToOutputWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor,
- m_Data.m_RecurrentToForgetWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToCellWeightsTensor,
- m_Data.m_RecurrentToCellWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor,
- m_Data.m_RecurrentToOutputWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_ForgetGateBiasTensor,
- m_Data.m_ForgetGateBias->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_CellBiasTensor,
- m_Data.m_CellBias->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_OutputGateBiasTensor,
- m_Data.m_OutputGateBias->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_InputToForgetWeightsTensor, m_Data.m_InputToForgetWeights);
+ InitializeArmComputeClTensorData(*m_InputToCellWeightsTensor, m_Data.m_InputToCellWeights);
+ InitializeArmComputeClTensorData(*m_InputToOutputWeightsTensor, m_Data.m_InputToOutputWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToForgetWeightsTensor, m_Data.m_RecurrentToForgetWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToCellWeightsTensor, m_Data.m_RecurrentToCellWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToOutputWeightsTensor, m_Data.m_RecurrentToOutputWeights);
+ InitializeArmComputeClTensorData(*m_ForgetGateBiasTensor, m_Data.m_ForgetGateBias);
+ InitializeArmComputeClTensorData(*m_CellBiasTensor, m_Data.m_CellBias);
+ InitializeArmComputeClTensorData(*m_OutputGateBiasTensor, m_Data.m_OutputGateBias);
if (!m_Data.m_Parameters.m_CifgEnabled)
{
- InitialiseArmComputeClTensorData(*m_InputToInputWeightsTensor,
- m_Data.m_InputToInputWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_RecurrentToInputWeightsTensor,
- m_Data.m_RecurrentToInputWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_InputToInputWeightsTensor, m_Data.m_InputToInputWeights);
+ InitializeArmComputeClTensorData(*m_RecurrentToInputWeightsTensor, m_Data.m_RecurrentToInputWeights);
if (m_Data.m_CellToInputWeights != nullptr)
{
- InitialiseArmComputeClTensorData(*m_CellToInputWeightsTensor,
- m_Data.m_CellToInputWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_CellToInputWeightsTensor, m_Data.m_CellToInputWeights);
}
- InitialiseArmComputeClTensorData(*m_InputGateBiasTensor,
- m_Data.m_InputGateBias->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_InputGateBiasTensor, m_Data.m_InputGateBias);
}
if (m_Data.m_Parameters.m_ProjectionEnabled)
{
- InitialiseArmComputeClTensorData(*m_ProjectionWeightsTensor,
- m_Data.m_ProjectionWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_ProjectionWeightsTensor, m_Data.m_ProjectionWeights);
if (m_Data.m_ProjectionBias != nullptr)
{
- InitialiseArmComputeClTensorData(*m_ProjectionBiasTensor,
- m_Data.m_ProjectionBias->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_ProjectionBiasTensor, m_Data.m_ProjectionBias);
}
}
if (m_Data.m_Parameters.m_PeepholeEnabled)
{
- InitialiseArmComputeClTensorData(*m_CellToForgetWeightsTensor,
- m_Data.m_CellToForgetWeights->GetConstTensor<float>());
- InitialiseArmComputeClTensorData(*m_CellToOutputWeightsTensor,
- m_Data.m_CellToOutputWeights->GetConstTensor<float>());
+ InitializeArmComputeClTensorData(*m_CellToForgetWeightsTensor, m_Data.m_CellToForgetWeights);
+ InitializeArmComputeClTensorData(*m_CellToOutputWeightsTensor, m_Data.m_CellToOutputWeights);
}
// Force Compute Library to perform the necessary copying and reshaping, after which