diff options
Diffstat (limited to 'src/armnn/layers/LstmLayer.cpp')
-rw-r--r-- | src/armnn/layers/LstmLayer.cpp | 228 |
1 files changed, 144 insertions, 84 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index 0eeb2f8eab..403d911e7e 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -303,35 +303,65 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() void LstmLayer::Accept(ILayerVisitor& visitor) const { LstmInputParams inputParams; + ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights); + ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights); + ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights); + ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights); + ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights); + ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights); + ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias); + ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias); + ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias); + + // Cifg parameters + ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights); + ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights); + ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias); + + // Projection parameters + ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights); + ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias); + + // Peephole parameters + ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights); + ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights); + ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights); + + // Layer normalisation parameters + ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights); + ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights); + ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights); + ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights); + ConstTensor inputToInputWeightsTensor; if (m_CifgParameters.m_InputToInputWeights != nullptr) { - ConstTensor inputToInputWeightsTensorCopy(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), - m_CifgParameters.m_InputToInputWeights->Map(true)); + ConstTensor inputToInputWeightsTensorCopy(managedInputToInputWeights.GetTensorInfo(), + managedInputToInputWeights.Map()); inputToInputWeightsTensor = inputToInputWeightsTensorCopy; inputParams.m_InputToInputWeights = &inputToInputWeightsTensor; } ConstTensor inputToForgetWeightsTensor; if (m_BasicParameters.m_InputToForgetWeights != nullptr) { - ConstTensor inputToForgetWeightsTensorCopy(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), - m_BasicParameters.m_InputToForgetWeights->Map(true)); + ConstTensor inputToForgetWeightsTensorCopy(managedInputToForgetWeights.GetTensorInfo(), + managedInputToForgetWeights.Map()); inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy; inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor; } ConstTensor inputToCellWeightsTensor; if (m_BasicParameters.m_InputToCellWeights != nullptr) { - ConstTensor inputToCellWeightsTensorCopy(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), - m_BasicParameters.m_InputToCellWeights->Map(true)); + ConstTensor inputToCellWeightsTensorCopy(managedInputToCellWeights.GetTensorInfo(), + managedInputToCellWeights.Map()); inputToCellWeightsTensor = inputToCellWeightsTensorCopy; inputParams.m_InputToCellWeights = &inputToCellWeightsTensor; } ConstTensor inputToOutputWeightsTensor; if (m_BasicParameters.m_InputToOutputWeights != nullptr) { - ConstTensor inputToOutputWeightsTensorCopy(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), - m_BasicParameters.m_InputToOutputWeights->Map(true)); + ConstTensor inputToOutputWeightsTensorCopy(managedInputToOutputWeights.GetTensorInfo(), + managedInputToOutputWeights.Map()); inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy; inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor; } @@ -339,8 +369,8 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const if (m_CifgParameters.m_RecurrentToInputWeights != nullptr) { ConstTensor recurrentToInputWeightsTensorCopy( - m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), - m_CifgParameters.m_RecurrentToInputWeights->Map(true)); + managedRecurrentToInputWeights.GetTensorInfo(), + managedRecurrentToInputWeights.Map()); recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy; inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; } @@ -348,8 +378,8 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr) { ConstTensor recurrentToForgetWeightsTensorCopy( - m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToForgetWeights->Map(true)); + managedRecurrentToForgetWeights.GetTensorInfo(), + managedRecurrentToForgetWeights.Map()); recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy; inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; } @@ -357,8 +387,8 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const if (m_BasicParameters.m_RecurrentToCellWeights != nullptr) { ConstTensor recurrentToCellWeightsTensorCopy( - m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToCellWeights->Map(true)); + managedRecurrentToCellWeights.GetTensorInfo(), + managedRecurrentToCellWeights.Map()); recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy; inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; } @@ -366,112 +396,112 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr) { ConstTensor recurrentToOutputWeightsTensorCopy( - m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToOutputWeights->Map(true)); + managedRecurrentToOutputWeights.GetTensorInfo(), + managedRecurrentToOutputWeights.Map()); recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy; inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; } ConstTensor cellToInputWeightsTensor; if (m_PeepholeParameters.m_CellToInputWeights != nullptr) { - ConstTensor cellToInputWeightsTensorCopy(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToInputWeights->Map(true)); + ConstTensor cellToInputWeightsTensorCopy(managedCellToInputWeights.GetTensorInfo(), + managedCellToInputWeights.Map()); cellToInputWeightsTensor = cellToInputWeightsTensorCopy; inputParams.m_CellToInputWeights = &cellToInputWeightsTensor; } ConstTensor cellToForgetWeightsTensor; if (m_PeepholeParameters.m_CellToForgetWeights != nullptr) { - ConstTensor cellToForgetWeightsTensorCopy(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToForgetWeights->Map(true)); + ConstTensor cellToForgetWeightsTensorCopy(managedCellToForgetWeights.GetTensorInfo(), + managedCellToForgetWeights.Map()); cellToForgetWeightsTensor = cellToForgetWeightsTensorCopy; inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor; } ConstTensor cellToOutputWeightsTensor; if (m_PeepholeParameters.m_CellToOutputWeights != nullptr) { - ConstTensor cellToOutputWeightsTensorCopy(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToOutputWeights->Map(true)); + ConstTensor cellToOutputWeightsTensorCopy(managedCellToOutputWeights.GetTensorInfo(), + managedCellToOutputWeights.Map()); cellToOutputWeightsTensor = cellToOutputWeightsTensorCopy; inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor; } ConstTensor inputGateBiasTensor; if (m_CifgParameters.m_InputGateBias != nullptr) { - ConstTensor inputGateBiasTensorCopy(m_CifgParameters.m_InputGateBias->GetTensorInfo(), - m_CifgParameters.m_InputGateBias->Map(true)); + ConstTensor inputGateBiasTensorCopy(managedInputGateBias.GetTensorInfo(), + managedInputGateBias.Map()); inputGateBiasTensor = inputGateBiasTensorCopy; inputParams.m_InputGateBias = &inputGateBiasTensor; } ConstTensor forgetGateBiasTensor; if (m_BasicParameters.m_ForgetGateBias != nullptr) { - ConstTensor forgetGateBiasTensorCopy(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), - m_BasicParameters.m_ForgetGateBias->Map(true)); + ConstTensor forgetGateBiasTensorCopy(managedForgetGateBias.GetTensorInfo(), + managedForgetGateBias.Map()); forgetGateBiasTensor = forgetGateBiasTensorCopy; inputParams.m_ForgetGateBias = &forgetGateBiasTensor; } ConstTensor cellBiasTensor; if (m_BasicParameters.m_CellBias != nullptr) { - ConstTensor cellBiasTensorCopy(m_BasicParameters.m_CellBias->GetTensorInfo(), - m_BasicParameters.m_CellBias->Map(true)); + ConstTensor cellBiasTensorCopy(managedCellBias.GetTensorInfo(), + managedCellBias.Map()); cellBiasTensor = cellBiasTensorCopy; inputParams.m_CellBias = &cellBiasTensor; } ConstTensor outputGateBias; if (m_BasicParameters.m_OutputGateBias != nullptr) { - ConstTensor outputGateBiasCopy(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), - m_BasicParameters.m_OutputGateBias->Map(true)); + ConstTensor outputGateBiasCopy(managedOutputGateBias.GetTensorInfo(), + managedOutputGateBias.Map()); outputGateBias = outputGateBiasCopy; inputParams.m_OutputGateBias = &outputGateBias; } ConstTensor projectionWeightsTensor; if (m_ProjectionParameters.m_ProjectionWeights != nullptr) { - ConstTensor projectionWeightsTensorCopy(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionWeights->Map(true)); + ConstTensor projectionWeightsTensorCopy(managedProjectionWeights.GetTensorInfo(), + managedProjectionWeights.Map()); projectionWeightsTensor = projectionWeightsTensorCopy; inputParams.m_ProjectionWeights = &projectionWeightsTensor; } ConstTensor projectionBiasTensor; if (m_ProjectionParameters.m_ProjectionBias != nullptr) { - ConstTensor projectionBiasTensorCopy(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionBias->Map(true)); + ConstTensor projectionBiasTensorCopy(managedProjectionBias.GetTensorInfo(), + managedProjectionBias.Map()); projectionBiasTensor = projectionBiasTensorCopy; inputParams.m_ProjectionBias = &projectionBiasTensor; } ConstTensor inputLayerNormTensor; if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr) { - ConstTensor inputLayerNormTensorCopy(m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_InputLayerNormWeights->Map(true)); + ConstTensor inputLayerNormTensorCopy(managedInputLayerNormWeights.GetTensorInfo(), + managedInputLayerNormWeights.Map()); inputLayerNormTensor = inputLayerNormTensorCopy; inputParams.m_InputLayerNormWeights = &inputLayerNormTensor; } ConstTensor forgetLayerNormTensor; if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr) { - ConstTensor forgetLayerNormTensorCopy(m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_ForgetLayerNormWeights->Map(true)); + ConstTensor forgetLayerNormTensorCopy(managedForgetLayerNormWeights.GetTensorInfo(), + managedForgetLayerNormWeights.Map()); forgetLayerNormTensor = forgetLayerNormTensorCopy; inputParams.m_ForgetLayerNormWeights = &forgetLayerNormTensor; } ConstTensor cellLayerNormTensor; if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr) { - ConstTensor cellLayerNormTensorCopy(m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_CellLayerNormWeights->Map(true)); + ConstTensor cellLayerNormTensorCopy(managedCellLayerNormWeights.GetTensorInfo(), + managedCellLayerNormWeights.Map()); cellLayerNormTensor = cellLayerNormTensorCopy; inputParams.m_CellLayerNormWeights = &cellLayerNormTensor; } ConstTensor outputLayerNormTensor; if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr) { - ConstTensor outputLayerNormTensorCopy(m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_OutputLayerNormWeights->Map(true)); + ConstTensor outputLayerNormTensorCopy(managedOutputLayerNormWeights.GetTensorInfo(), + managedOutputLayerNormWeights.Map()); outputLayerNormTensor = outputLayerNormTensorCopy; inputParams.m_OutputLayerNormWeights = &outputLayerNormTensor; } @@ -486,54 +516,84 @@ void LstmLayer::ExecuteStrategy(IStrategy& strategy) const LstmDescriptor descriptor = GetParameters(); + ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights); + ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights); + ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights); + ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights); + ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights); + ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights); + ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias); + ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias); + ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias); + + // Cifg parameters + ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights); + ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights); + ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias); + + // Projection parameters + ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights); + ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias); + + // Peephole parameters + ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights); + ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights); + ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights); + + // Layer normalisation parameters + ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights); + ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights); + ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights); + ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights); + // First add mandatory/basic parameters if (m_BasicParameters.m_InputToForgetWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), - m_BasicParameters.m_InputToForgetWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(), + managedInputToForgetWeights.Map())); } if (m_BasicParameters.m_InputToCellWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), - m_BasicParameters.m_InputToCellWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(), + managedInputToCellWeights.Map())); } if (m_BasicParameters.m_InputToOutputWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), - m_BasicParameters.m_InputToOutputWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(), + managedInputToOutputWeights.Map())); } if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToForgetWeights->Map(true))); + managedRecurrentToForgetWeights.GetTensorInfo(), + managedRecurrentToForgetWeights.Map())); } if (m_BasicParameters.m_RecurrentToCellWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToCellWeights->Map(true))); + managedRecurrentToCellWeights.GetTensorInfo(), + managedRecurrentToCellWeights.Map())); } if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), - m_BasicParameters.m_RecurrentToOutputWeights->Map(true))); + managedRecurrentToOutputWeights.GetTensorInfo(), + managedRecurrentToOutputWeights.Map())); } if (m_BasicParameters.m_ForgetGateBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), - m_BasicParameters.m_ForgetGateBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(), + managedForgetGateBias.Map())); } if (m_BasicParameters.m_CellBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_CellBias->GetTensorInfo(), - m_BasicParameters.m_CellBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(), + managedCellBias.Map())); } if (m_BasicParameters.m_OutputGateBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), - m_BasicParameters.m_OutputGateBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(), + managedOutputGateBias.Map())); } // Add cifg parameters @@ -541,19 +601,19 @@ void LstmLayer::ExecuteStrategy(IStrategy& strategy) const { if (m_CifgParameters.m_InputToInputWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), - m_CifgParameters.m_InputToInputWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(), + managedInputToInputWeights.Map())); } if (m_CifgParameters.m_RecurrentToInputWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), - m_CifgParameters.m_RecurrentToInputWeights->Map(true))); + managedRecurrentToInputWeights.GetTensorInfo(), + managedRecurrentToInputWeights.Map())); } if (m_CifgParameters.m_InputGateBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(), - m_CifgParameters.m_InputGateBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(), + managedInputGateBias.Map())); } } @@ -564,19 +624,19 @@ void LstmLayer::ExecuteStrategy(IStrategy& strategy) const { if (m_PeepholeParameters.m_CellToInputWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToInputWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(), + managedCellToInputWeights.Map())); } } if (m_PeepholeParameters.m_CellToForgetWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToForgetWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(), + managedCellToForgetWeights.Map())); } if (m_PeepholeParameters.m_CellToOutputWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), - m_PeepholeParameters.m_CellToOutputWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(), + managedCellToOutputWeights.Map())); } } @@ -585,13 +645,13 @@ void LstmLayer::ExecuteStrategy(IStrategy& strategy) const { if (m_ProjectionParameters.m_ProjectionWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(), + managedProjectionWeights.Map())); } if (m_ProjectionParameters.m_ProjectionBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), - m_ProjectionParameters.m_ProjectionBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(), + managedProjectionBias.Map())); } } @@ -602,24 +662,24 @@ void LstmLayer::ExecuteStrategy(IStrategy& strategy) const { if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_InputLayerNormWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(), + managedInputLayerNormWeights.Map())); } } if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_ForgetLayerNormWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(), + managedForgetLayerNormWeights.Map())); } if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_CellLayerNormWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(), + managedCellLayerNormWeights.Map())); } if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), - m_LayerNormParameters.m_OutputLayerNormWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(), + managedOutputLayerNormWeights.Map())); } } |