aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/LstmLayer.cpp
diff options
context:
space:
mode:
authorJim Flynn <jim.flynn@arm.com>2019-03-19 17:22:29 +0000
committerJim Flynn <jim.flynn@arm.com>2019-03-21 16:09:19 +0000
commit11af375a5a6bf88b4f3b933a86d53000b0d91ed0 (patch)
treef4f4db5192b275be44d96d96c7f3c8c10f15b3f1 /src/armnn/layers/LstmLayer.cpp
parentdb059fd50f9afb398b8b12cd4592323fc8f60d7f (diff)
downloadarmnn-11af375a5a6bf88b4f3b933a86d53000b0d91ed0.tar.gz
IVGCVSW-2694: serialize/deserialize LSTM
* added serialize/deserialize methods for LSTM and tests Change-Id: Ic59557f03001c496008c4bef92c2e0406e1fbc6c Signed-off-by: Nina Drozd <nina.drozd@arm.com> Signed-off-by: Jim Flynn <jim.flynn@arm.com>
Diffstat (limited to 'src/armnn/layers/LstmLayer.cpp')
-rw-r--r--src/armnn/layers/LstmLayer.cpp92
1 files changed, 63 insertions, 29 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp
index fa836d0317..2b99f284e8 100644
--- a/src/armnn/layers/LstmLayer.cpp
+++ b/src/armnn/layers/LstmLayer.cpp
@@ -252,110 +252,144 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef()
void LstmLayer::Accept(ILayerVisitor& visitor) const
{
LstmInputParams inputParams;
+ ConstTensor inputToInputWeightsTensor;
if (m_CifgParameters.m_InputToInputWeights != nullptr)
{
- ConstTensor inputToInputWeightsTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(),
- m_CifgParameters.m_InputToInputWeights->Map(true));
+ ConstTensor inputToInputWeightsTensorCopy(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(),
+ m_CifgParameters.m_InputToInputWeights->Map(true));
+ inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
}
+ ConstTensor inputToForgetWeightsTensor;
if (m_BasicParameters.m_InputToForgetWeights != nullptr)
{
- ConstTensor inputToForgetWeightsTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToForgetWeights->Map(true));
+ ConstTensor inputToForgetWeightsTensorCopy(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(),
+ m_BasicParameters.m_InputToForgetWeights->Map(true));
+ inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
}
+ ConstTensor inputToCellWeightsTensor;
if (m_BasicParameters.m_InputToCellWeights != nullptr)
{
- ConstTensor inputToCellWeightsTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToCellWeights->Map(true));
+ ConstTensor inputToCellWeightsTensorCopy(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(),
+ m_BasicParameters.m_InputToCellWeights->Map(true));
+ inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
}
+ ConstTensor inputToOutputWeightsTensor;
if (m_BasicParameters.m_InputToOutputWeights != nullptr)
{
- ConstTensor inputToOutputWeightsTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToOutputWeights->Map(true));
+ ConstTensor inputToOutputWeightsTensorCopy(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(),
+ m_BasicParameters.m_InputToOutputWeights->Map(true));
+ inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
}
+ ConstTensor recurrentToInputWeightsTensor;
if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
{
- ConstTensor recurrentToInputWeightsTensor(
+ ConstTensor recurrentToInputWeightsTensorCopy(
m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(),
m_CifgParameters.m_RecurrentToInputWeights->Map(true));
+ recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
}
+ ConstTensor recurrentToForgetWeightsTensor;
if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
{
- ConstTensor recurrentToForgetWeightsTensor(
+ ConstTensor recurrentToForgetWeightsTensorCopy(
m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
m_BasicParameters.m_RecurrentToForgetWeights->Map(true));
+ recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
}
+ ConstTensor recurrentToCellWeightsTensor;
if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
{
- ConstTensor recurrentToCellWeightsTensor(
+ ConstTensor recurrentToCellWeightsTensorCopy(
m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(),
m_BasicParameters.m_RecurrentToCellWeights->Map(true));
+ recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
}
+ ConstTensor recurrentToOutputWeightsTensor;
if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
{
- ConstTensor recurrentToOutputWeightsTensor(
+ ConstTensor recurrentToOutputWeightsTensorCopy(
m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
m_BasicParameters.m_RecurrentToOutputWeights->Map(true));
+ recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
}
+ ConstTensor cellToInputWeightsTensor;
if (m_CifgParameters.m_CellToInputWeights != nullptr)
{
- ConstTensor cellToInputWeightsTensor(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(),
- m_CifgParameters.m_CellToInputWeights->Map(true));
+ ConstTensor cellToInputWeightsTensorCopy(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(),
+ m_CifgParameters.m_CellToInputWeights->Map(true));
+ cellToInputWeightsTensor = cellToInputWeightsTensorCopy;
inputParams.m_CellToInputWeights = &cellToInputWeightsTensor;
}
+ ConstTensor cellToForgetWeightsTensor;
if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
{
- ConstTensor cellToForgetWeightsTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
- m_PeepholeParameters.m_CellToForgetWeights->Map(true));
+ ConstTensor cellToForgetWeightsTensorCopy(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
+ m_PeepholeParameters.m_CellToForgetWeights->Map(true));
+ cellToForgetWeightsTensor = cellToForgetWeightsTensorCopy;
inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor;
}
+ ConstTensor cellToOutputWeightsTensor;
if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
{
- ConstTensor cellToOutputWeightsTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
- m_PeepholeParameters.m_CellToOutputWeights->Map(true));
+ ConstTensor cellToOutputWeightsTensorCopy(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
+ m_PeepholeParameters.m_CellToOutputWeights->Map(true));
+ cellToOutputWeightsTensor = cellToOutputWeightsTensorCopy;
inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor;
}
+ ConstTensor inputGateBiasTensor;
if (m_CifgParameters.m_InputGateBias != nullptr)
{
- ConstTensor inputGateBiasTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(),
+ ConstTensor inputGateBiasTensorCopy(m_CifgParameters.m_InputGateBias->GetTensorInfo(),
m_CifgParameters.m_InputGateBias->Map(true));
+ inputGateBiasTensor = inputGateBiasTensorCopy;
inputParams.m_InputGateBias = &inputGateBiasTensor;
}
+ ConstTensor forgetGateBiasTensor;
if (m_BasicParameters.m_ForgetGateBias != nullptr)
{
- ConstTensor forgetGateBiasTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(),
- m_BasicParameters.m_ForgetGateBias->Map(true));
+ ConstTensor forgetGateBiasTensorCopy(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(),
+ m_BasicParameters.m_ForgetGateBias->Map(true));
+ forgetGateBiasTensor = forgetGateBiasTensorCopy;
inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
}
+ ConstTensor cellBiasTensor;
if (m_BasicParameters.m_CellBias != nullptr)
{
- ConstTensor cellBiasTensor(m_BasicParameters.m_CellBias->GetTensorInfo(),
- m_BasicParameters.m_CellBias->Map(true));
+ ConstTensor cellBiasTensorCopy(m_BasicParameters.m_CellBias->GetTensorInfo(),
+ m_BasicParameters.m_CellBias->Map(true));
+ cellBiasTensor = cellBiasTensorCopy;
inputParams.m_CellBias = &cellBiasTensor;
}
+ ConstTensor outputGateBias;
if (m_BasicParameters.m_OutputGateBias != nullptr)
{
- ConstTensor outputGateBias(m_BasicParameters.m_OutputGateBias->GetTensorInfo(),
- m_BasicParameters.m_OutputGateBias->Map(true));
+ ConstTensor outputGateBiasCopy(m_BasicParameters.m_OutputGateBias->GetTensorInfo(),
+ m_BasicParameters.m_OutputGateBias->Map(true));
+ outputGateBias = outputGateBiasCopy;
inputParams.m_OutputGateBias = &outputGateBias;
}
+ ConstTensor projectionWeightsTensor;
if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
{
- ConstTensor projectionWeightsTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
- m_ProjectionParameters.m_ProjectionWeights->Map(true));
+ ConstTensor projectionWeightsTensorCopy(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
+ m_ProjectionParameters.m_ProjectionWeights->Map(true));
+ projectionWeightsTensor = projectionWeightsTensorCopy;
inputParams.m_ProjectionWeights = &projectionWeightsTensor;
}
+ ConstTensor projectionBiasTensor;
if (m_ProjectionParameters.m_ProjectionBias != nullptr)
{
- ConstTensor projectionBiasTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
- m_ProjectionParameters.m_ProjectionBias->Map(true));
+ ConstTensor projectionBiasTensorCopy(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
+ m_ProjectionParameters.m_ProjectionBias->Map(true));
+ projectionBiasTensor = projectionBiasTensorCopy;
inputParams.m_ProjectionBias = &projectionBiasTensor;
}