diff options
Diffstat (limited to 'src/armnn/layers/QLstmLayer.cpp')
-rw-r--r-- | src/armnn/layers/QLstmLayer.cpp | 226 |
1 files changed, 142 insertions, 84 deletions
diff --git a/src/armnn/layers/QLstmLayer.cpp b/src/armnn/layers/QLstmLayer.cpp index 16aa718eb9..72b020f109 100644 --- a/src/armnn/layers/QLstmLayer.cpp +++ b/src/armnn/layers/QLstmLayer.cpp @@ -305,12 +305,41 @@ Layer::ConstantTensors QLstmLayer::GetConstantTensorsByRef() void QLstmLayer::Accept(ILayerVisitor& visitor) const { LstmInputParams inputParams; + ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights); + ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights); + ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights); + ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights); + ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights); + ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights); + ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias); + ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias); + ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias); + + // Cifg parameters + ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights); + ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights); + ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias); + + // Projection parameters + ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights); + ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias); + + // Peephole parameters + ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights); + ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights); + ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights); + + // Layer normalisation parameters + ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights); + ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights); + ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights); + ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights); ConstTensor inputToInputWeightsTensor; if (m_CifgParameters.m_InputToInputWeights != nullptr) { - ConstTensor inputToInputWeightsTensorCopy(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), - m_CifgParameters.m_InputToInputWeights->Map(true)); + ConstTensor inputToInputWeightsTensorCopy(managedInputToInputWeights.GetTensorInfo(), + managedInputToInputWeights.Map()); inputToInputWeightsTensor = inputToInputWeightsTensorCopy; inputParams.m_InputToInputWeights = &inputToInputWeightsTensor; } @@ -318,8 +347,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor inputToForgetWeightsTensor; if (m_BasicParameters.m_InputToForgetWeights != nullptr) { - ConstTensor inputToForgetWeightsTensorCopy(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), - m_BasicParameters.m_InputToForgetWeights->Map(true)); + ConstTensor inputToForgetWeightsTensorCopy(managedInputToForgetWeights.GetTensorInfo(), + managedInputToForgetWeights.Map()); inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy; inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor; } @@ -327,8 +356,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor inputToCellWeightsTensor; if (m_BasicParameters.m_InputToCellWeights != nullptr) { - ConstTensor inputToCellWeightsTensorCopy(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), - m_BasicParameters.m_InputToCellWeights->Map(true)); + ConstTensor inputToCellWeightsTensorCopy(managedInputToCellWeights.GetTensorInfo(), + managedInputToCellWeights.Map()); inputToCellWeightsTensor = inputToCellWeightsTensorCopy; inputParams.m_InputToCellWeights = &inputToCellWeightsTensor; } @@ -336,8 +365,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor inputToOutputWeightsTensor; if (m_BasicParameters.m_InputToOutputWeights != nullptr) { - ConstTensor inputToOutputWeightsTensorCopy(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), - m_BasicParameters.m_InputToOutputWeights->Map(true)); + ConstTensor inputToOutputWeightsTensorCopy(managedInputToOutputWeights.GetTensorInfo(), + managedInputToOutputWeights.Map()); inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy; inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor; } @@ -346,8 +375,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const if (m_CifgParameters.m_RecurrentToInputWeights != nullptr) { ConstTensor recurrentToInputWeightsTensorCopy( - m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), - m_CifgParameters.m_RecurrentToInputWeights->Map(true)); + managedRecurrentToInputWeights.GetTensorInfo(), + managedRecurrentToInputWeights.Map()); recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy; inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; } @@ -356,8 +385,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr) { ConstTensor recurrentToForgetWeightsTensorCopy( - m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToForgetWeights->Map(true)); + managedRecurrentToForgetWeights.GetTensorInfo(), + managedRecurrentToForgetWeights.Map()); recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy; inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; } @@ -366,8 +395,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const if (m_BasicParameters.m_RecurrentToCellWeights != nullptr) { ConstTensor recurrentToCellWeightsTensorCopy( - m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToCellWeights->Map(true)); + managedRecurrentToCellWeights.GetTensorInfo(), + managedRecurrentToCellWeights.Map()); recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy; inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; } @@ -376,8 +405,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr) { ConstTensor recurrentToOutputWeightsTensorCopy( - m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToOutputWeights->Map(true)); + managedRecurrentToOutputWeights.GetTensorInfo(), + managedRecurrentToOutputWeights.Map()); recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy; inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; } @@ -385,8 +414,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor cellToInputWeightsTensor; if (m_PeepholeParameters.m_CellToInputWeights != nullptr) { - ConstTensor cellToInputWeightsTensorCopy(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToInputWeights->Map(true)); + ConstTensor cellToInputWeightsTensorCopy(managedCellToInputWeights.GetTensorInfo(), + managedCellToInputWeights.Map()); cellToInputWeightsTensor = cellToInputWeightsTensorCopy; inputParams.m_CellToInputWeights = &cellToInputWeightsTensor; } @@ -394,8 +423,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor cellToForgetWeightsTensor; if (m_PeepholeParameters.m_CellToForgetWeights != nullptr) { - ConstTensor cellToForgetWeightsTensorCopy(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToForgetWeights->Map(true)); + ConstTensor cellToForgetWeightsTensorCopy(managedCellToForgetWeights.GetTensorInfo(), + managedCellToForgetWeights.Map()); cellToForgetWeightsTensor = cellToForgetWeightsTensorCopy; inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor; } @@ -403,8 +432,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor cellToOutputWeightsTensor; if (m_PeepholeParameters.m_CellToOutputWeights != nullptr) { - ConstTensor cellToOutputWeightsTensorCopy(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToOutputWeights->Map(true)); + ConstTensor cellToOutputWeightsTensorCopy(managedCellToOutputWeights.GetTensorInfo(), + managedCellToOutputWeights.Map()); cellToOutputWeightsTensor = cellToOutputWeightsTensorCopy; inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor; } @@ -412,8 +441,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor inputGateBiasTensor; if (m_CifgParameters.m_InputGateBias != nullptr) { - ConstTensor inputGateBiasTensorCopy(m_CifgParameters.m_InputGateBias->GetTensorInfo(), - m_CifgParameters.m_InputGateBias->Map(true)); + ConstTensor inputGateBiasTensorCopy(managedInputGateBias.GetTensorInfo(), + managedInputGateBias.Map()); inputGateBiasTensor = inputGateBiasTensorCopy; inputParams.m_InputGateBias = &inputGateBiasTensor; } @@ -421,8 +450,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor forgetGateBiasTensor; if (m_BasicParameters.m_ForgetGateBias != nullptr) { - ConstTensor forgetGateBiasTensorCopy(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), - m_BasicParameters.m_ForgetGateBias->Map(true)); + ConstTensor forgetGateBiasTensorCopy(managedForgetGateBias.GetTensorInfo(), + managedForgetGateBias.Map()); forgetGateBiasTensor = forgetGateBiasTensorCopy; inputParams.m_ForgetGateBias = &forgetGateBiasTensor; } @@ -430,8 +459,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor cellBiasTensor; if (m_BasicParameters.m_CellBias != nullptr) { - ConstTensor cellBiasTensorCopy(m_BasicParameters.m_CellBias->GetTensorInfo(), - m_BasicParameters.m_CellBias->Map(true)); + ConstTensor cellBiasTensorCopy(managedCellBias.GetTensorInfo(), + managedCellBias.Map()); cellBiasTensor = cellBiasTensorCopy; inputParams.m_CellBias = &cellBiasTensor; } @@ -439,8 +468,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor outputGateBias; if (m_BasicParameters.m_OutputGateBias != nullptr) { - ConstTensor outputGateBiasCopy(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), - m_BasicParameters.m_OutputGateBias->Map(true)); + ConstTensor outputGateBiasCopy(managedOutputGateBias.GetTensorInfo(), + managedOutputGateBias.Map()); outputGateBias = outputGateBiasCopy; inputParams.m_OutputGateBias = &outputGateBias; } @@ -448,8 +477,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor projectionWeightsTensor; if (m_ProjectionParameters.m_ProjectionWeights != nullptr) { - ConstTensor projectionWeightsTensorCopy(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionWeights->Map(true)); + ConstTensor projectionWeightsTensorCopy(managedProjectionWeights.GetTensorInfo(), + managedProjectionWeights.Map()); projectionWeightsTensor = projectionWeightsTensorCopy; inputParams.m_ProjectionWeights = &projectionWeightsTensor; } @@ -457,8 +486,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor projectionBiasTensor; if (m_ProjectionParameters.m_ProjectionBias != nullptr) { - ConstTensor projectionBiasTensorCopy(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionBias->Map(true)); + ConstTensor projectionBiasTensorCopy(managedProjectionBias.GetTensorInfo(), + managedProjectionBias.Map()); projectionBiasTensor = projectionBiasTensorCopy; inputParams.m_ProjectionBias = &projectionBiasTensor; } @@ -466,8 +495,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor inputLayerNormTensor; if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr) { - ConstTensor inputLayerNormTensorCopy(m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_InputLayerNormWeights->Map(true)); + ConstTensor inputLayerNormTensorCopy(managedInputLayerNormWeights.GetTensorInfo(), + managedInputLayerNormWeights.Map()); inputLayerNormTensor = inputLayerNormTensorCopy; inputParams.m_InputLayerNormWeights = &inputLayerNormTensor; } @@ -475,8 +504,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor forgetLayerNormTensor; if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr) { - ConstTensor forgetLayerNormTensorCopy(m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_ForgetLayerNormWeights->Map(true)); + ConstTensor forgetLayerNormTensorCopy(managedForgetLayerNormWeights.GetTensorInfo(), + managedForgetLayerNormWeights.Map()); forgetLayerNormTensor = forgetLayerNormTensorCopy; inputParams.m_ForgetLayerNormWeights = &forgetLayerNormTensor; } @@ -484,8 +513,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor cellLayerNormTensor; if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr) { - ConstTensor cellLayerNormTensorCopy(m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_CellLayerNormWeights->Map(true)); + ConstTensor cellLayerNormTensorCopy(managedCellLayerNormWeights.GetTensorInfo(), + managedCellLayerNormWeights.Map()); cellLayerNormTensor = cellLayerNormTensorCopy; inputParams.m_CellLayerNormWeights = &cellLayerNormTensor; } @@ -493,8 +522,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor outputLayerNormTensor; if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr) { - ConstTensor outputLayerNormTensorCopy(m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_OutputLayerNormWeights->Map(true)); + ConstTensor outputLayerNormTensorCopy(managedOutputLayerNormWeights.GetTensorInfo(), + managedOutputLayerNormWeights.Map()); outputLayerNormTensor = outputLayerNormTensorCopy; inputParams.m_OutputLayerNormWeights = &outputLayerNormTensor; } @@ -507,124 +536,153 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const void QLstmLayer::ExecuteStrategy(IStrategy& strategy) const { std::vector<ConstTensor> constTensors; + ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights); + ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights); + ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights); + ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights); + ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights); + ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights); + ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias); + ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias); + ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias); + + // Cifg parameters + ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights); + ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights); + ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias); + + // Projection parameters + ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights); + ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias); + + // Peephole parameters + ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights); + ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights); + ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights); + + // Layer normalisation parameters + ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights); + ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights); + ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights); + ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights); // First add mandatory/basic parameters if (m_BasicParameters.m_InputToForgetWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), - m_BasicParameters.m_InputToForgetWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(), + managedInputToForgetWeights.Map())); } if (m_BasicParameters.m_InputToCellWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), - m_BasicParameters.m_InputToCellWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(), + managedInputToCellWeights.Map())); } if (m_BasicParameters.m_InputToOutputWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), - m_BasicParameters.m_InputToOutputWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(), + managedInputToOutputWeights.Map())); } if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToForgetWeights->Map(true))); + managedRecurrentToForgetWeights.GetTensorInfo(), + managedRecurrentToForgetWeights.Map())); } if (m_BasicParameters.m_RecurrentToCellWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToCellWeights->Map(true))); + managedRecurrentToCellWeights.GetTensorInfo(), + managedRecurrentToCellWeights.Map())); } if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToOutputWeights->Map(true))); + managedRecurrentToOutputWeights.GetTensorInfo(), + managedRecurrentToOutputWeights.Map())); } if (m_BasicParameters.m_ForgetGateBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), - m_BasicParameters.m_ForgetGateBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(), + managedForgetGateBias.Map())); } if (m_BasicParameters.m_CellBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_CellBias->GetTensorInfo(), - m_BasicParameters.m_CellBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(), + managedCellBias.Map())); } if (m_BasicParameters.m_OutputGateBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), - m_BasicParameters.m_OutputGateBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(), + managedOutputGateBias.Map())); } // Add cifig parameters if (m_CifgParameters.m_InputToInputWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), - m_CifgParameters.m_InputToInputWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(), + managedInputToInputWeights.Map())); } if (m_CifgParameters.m_RecurrentToInputWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), - m_CifgParameters.m_RecurrentToInputWeights->Map(true))); + managedRecurrentToInputWeights.GetTensorInfo(), + managedRecurrentToInputWeights.Map())); } if (m_CifgParameters.m_InputGateBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(), - m_CifgParameters.m_InputGateBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(), + managedInputGateBias.Map())); } // Add peephole parameters if (m_PeepholeParameters.m_CellToInputWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToInputWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(), + managedCellToInputWeights.Map())); } if (m_PeepholeParameters.m_CellToForgetWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToForgetWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(), + managedCellToForgetWeights.Map())); } if (m_PeepholeParameters.m_CellToOutputWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToOutputWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(), + managedCellToOutputWeights.Map())); } // Add projection parameters if (m_ProjectionParameters.m_ProjectionWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(), + managedProjectionWeights.Map())); } if (m_ProjectionParameters.m_ProjectionBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(), + managedProjectionBias.Map())); } // Add norm parameters if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_InputLayerNormWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(), + managedInputLayerNormWeights.Map())); } if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_ForgetLayerNormWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(), + managedForgetLayerNormWeights.Map())); } if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_CellLayerNormWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(), + managedCellLayerNormWeights.Map())); } if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_OutputLayerNormWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(), + managedOutputLayerNormWeights.Map())); } strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName()); } |