From 11af375a5a6bf88b4f3b933a86d53000b0d91ed0 Mon Sep 17 00:00:00 2001 From: Jim Flynn Date: Tue, 19 Mar 2019 17:22:29 +0000 Subject: IVGCVSW-2694: serialize/deserialize LSTM * added serialize/deserialize methods for LSTM and tests Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c Signed-off-by: Nina Drozd Signed-off-by: Jim Flynn --- src/armnn/layers/LstmLayer.cpp | 92 +++++++++++++++++++++++++++++------------- 1 file changed, 63 insertions(+), 29 deletions(-) (limited to 'src/armnn/layers/LstmLayer.cpp') diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index fa836d0317..2b99f284e8 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -252,110 +252,144 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() void LstmLayer::Accept(ILayerVisitor& visitor) const { LstmInputParams inputParams; + ConstTensor inputToInputWeightsTensor; if (m_CifgParameters.m_InputToInputWeights != nullptr) { - ConstTensor inputToInputWeightsTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), - m_CifgParameters.m_InputToInputWeights->Map(true)); + ConstTensor inputToInputWeightsTensorCopy(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), + m_CifgParameters.m_InputToInputWeights->Map(true)); + inputToInputWeightsTensor = inputToInputWeightsTensorCopy; inputParams.m_InputToInputWeights = &inputToInputWeightsTensor; } + ConstTensor inputToForgetWeightsTensor; if (m_BasicParameters.m_InputToForgetWeights != nullptr) { - ConstTensor inputToForgetWeightsTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), - m_BasicParameters.m_InputToForgetWeights->Map(true)); + ConstTensor inputToForgetWeightsTensorCopy(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), + m_BasicParameters.m_InputToForgetWeights->Map(true)); + inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy; inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor; } + ConstTensor inputToCellWeightsTensor; if (m_BasicParameters.m_InputToCellWeights != nullptr) { - ConstTensor inputToCellWeightsTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), - m_BasicParameters.m_InputToCellWeights->Map(true)); + ConstTensor inputToCellWeightsTensorCopy(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), + m_BasicParameters.m_InputToCellWeights->Map(true)); + inputToCellWeightsTensor = inputToCellWeightsTensorCopy; inputParams.m_InputToCellWeights = &inputToCellWeightsTensor; } + ConstTensor inputToOutputWeightsTensor; if (m_BasicParameters.m_InputToOutputWeights != nullptr) { - ConstTensor inputToOutputWeightsTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), - m_BasicParameters.m_InputToOutputWeights->Map(true)); + ConstTensor inputToOutputWeightsTensorCopy(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), + m_BasicParameters.m_InputToOutputWeights->Map(true)); + inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy; inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor; } + ConstTensor recurrentToInputWeightsTensor; if (m_CifgParameters.m_RecurrentToInputWeights != nullptr) { - ConstTensor recurrentToInputWeightsTensor( + ConstTensor recurrentToInputWeightsTensorCopy( m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), m_CifgParameters.m_RecurrentToInputWeights->Map(true)); + recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy; inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; } + ConstTensor recurrentToForgetWeightsTensor; if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr) { - ConstTensor recurrentToForgetWeightsTensor( + ConstTensor recurrentToForgetWeightsTensorCopy( m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), m_BasicParameters.m_RecurrentToForgetWeights->Map(true)); + recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy; inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; } + ConstTensor recurrentToCellWeightsTensor; if (m_BasicParameters.m_RecurrentToCellWeights != nullptr) { - ConstTensor recurrentToCellWeightsTensor( + ConstTensor recurrentToCellWeightsTensorCopy( m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), m_BasicParameters.m_RecurrentToCellWeights->Map(true)); + recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy; inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; } + ConstTensor recurrentToOutputWeightsTensor; if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr) { - ConstTensor recurrentToOutputWeightsTensor( + ConstTensor recurrentToOutputWeightsTensorCopy( m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), m_BasicParameters.m_RecurrentToOutputWeights->Map(true)); + recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy; inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; } + ConstTensor cellToInputWeightsTensor; if (m_CifgParameters.m_CellToInputWeights != nullptr) { - ConstTensor cellToInputWeightsTensor(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), - m_CifgParameters.m_CellToInputWeights->Map(true)); + ConstTensor cellToInputWeightsTensorCopy(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), + m_CifgParameters.m_CellToInputWeights->Map(true)); + cellToInputWeightsTensor = cellToInputWeightsTensorCopy; inputParams.m_CellToInputWeights = &cellToInputWeightsTensor; } + ConstTensor cellToForgetWeightsTensor; if (m_PeepholeParameters.m_CellToForgetWeights != nullptr) { - ConstTensor cellToForgetWeightsTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToForgetWeights->Map(true)); + ConstTensor cellToForgetWeightsTensorCopy(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToForgetWeights->Map(true)); + cellToForgetWeightsTensor = cellToForgetWeightsTensorCopy; inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor; } + ConstTensor cellToOutputWeightsTensor; if (m_PeepholeParameters.m_CellToOutputWeights != nullptr) { - ConstTensor cellToOutputWeightsTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToOutputWeights->Map(true)); + ConstTensor cellToOutputWeightsTensorCopy(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToOutputWeights->Map(true)); + cellToOutputWeightsTensor = cellToOutputWeightsTensorCopy; inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor; } + ConstTensor inputGateBiasTensor; if (m_CifgParameters.m_InputGateBias != nullptr) { - ConstTensor inputGateBiasTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(), + ConstTensor inputGateBiasTensorCopy(m_CifgParameters.m_InputGateBias->GetTensorInfo(), m_CifgParameters.m_InputGateBias->Map(true)); + inputGateBiasTensor = inputGateBiasTensorCopy; inputParams.m_InputGateBias = &inputGateBiasTensor; } + ConstTensor forgetGateBiasTensor; if (m_BasicParameters.m_ForgetGateBias != nullptr) { - ConstTensor forgetGateBiasTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), - m_BasicParameters.m_ForgetGateBias->Map(true)); + ConstTensor forgetGateBiasTensorCopy(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), + m_BasicParameters.m_ForgetGateBias->Map(true)); + forgetGateBiasTensor = forgetGateBiasTensorCopy; inputParams.m_ForgetGateBias = &forgetGateBiasTensor; } + ConstTensor cellBiasTensor; if (m_BasicParameters.m_CellBias != nullptr) { - ConstTensor cellBiasTensor(m_BasicParameters.m_CellBias->GetTensorInfo(), - m_BasicParameters.m_CellBias->Map(true)); + ConstTensor cellBiasTensorCopy(m_BasicParameters.m_CellBias->GetTensorInfo(), + m_BasicParameters.m_CellBias->Map(true)); + cellBiasTensor = cellBiasTensorCopy; inputParams.m_CellBias = &cellBiasTensor; } + ConstTensor outputGateBias; if (m_BasicParameters.m_OutputGateBias != nullptr) { - ConstTensor outputGateBias(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), - m_BasicParameters.m_OutputGateBias->Map(true)); + ConstTensor outputGateBiasCopy(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), + m_BasicParameters.m_OutputGateBias->Map(true)); + outputGateBias = outputGateBiasCopy; inputParams.m_OutputGateBias = &outputGateBias; } + ConstTensor projectionWeightsTensor; if (m_ProjectionParameters.m_ProjectionWeights != nullptr) { - ConstTensor projectionWeightsTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionWeights->Map(true)); + ConstTensor projectionWeightsTensorCopy(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), + m_ProjectionParameters.m_ProjectionWeights->Map(true)); + projectionWeightsTensor = projectionWeightsTensorCopy; inputParams.m_ProjectionWeights = &projectionWeightsTensor; } + ConstTensor projectionBiasTensor; if (m_ProjectionParameters.m_ProjectionBias != nullptr) { - ConstTensor projectionBiasTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionBias->Map(true)); + ConstTensor projectionBiasTensorCopy(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), + m_ProjectionParameters.m_ProjectionBias->Map(true)); + projectionBiasTensor = projectionBiasTensorCopy; inputParams.m_ProjectionBias = &projectionBiasTensor; } -- cgit v1.2.1