diff options
Diffstat (limited to 'src/armnn/layers/QuantizedLstmLayer.cpp')
-rw-r--r-- | src/armnn/layers/QuantizedLstmLayer.cpp | 126 |
1 files changed, 78 insertions, 48 deletions
diff --git a/src/armnn/layers/QuantizedLstmLayer.cpp b/src/armnn/layers/QuantizedLstmLayer.cpp index a1ff985abe..4d0dab9505 100644 --- a/src/armnn/layers/QuantizedLstmLayer.cpp +++ b/src/armnn/layers/QuantizedLstmLayer.cpp @@ -173,12 +173,27 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const { QuantizedLstmInputParams inputParams; + ManagedConstTensorHandle managedInputToInputWeights(m_QuantizedLstmParameters.m_InputToInputWeights); + ManagedConstTensorHandle managedInputToForgetWeights(m_QuantizedLstmParameters.m_InputToForgetWeights); + ManagedConstTensorHandle managedInputToCellWeights(m_QuantizedLstmParameters.m_InputToCellWeights); + ManagedConstTensorHandle managedInputToOutputWeights(m_QuantizedLstmParameters.m_InputToOutputWeights); + + ManagedConstTensorHandle managedRecurrentToInputWeights(m_QuantizedLstmParameters.m_RecurrentToInputWeights); + ManagedConstTensorHandle managedRecurrentToForgetWeights(m_QuantizedLstmParameters.m_RecurrentToForgetWeights); + ManagedConstTensorHandle managedRecurrentToCellWeights(m_QuantizedLstmParameters.m_RecurrentToCellWeights); + ManagedConstTensorHandle managedRecurrentToOutputWeights(m_QuantizedLstmParameters.m_RecurrentToOutputWeights); + + ManagedConstTensorHandle managedInputGateBias(m_QuantizedLstmParameters.m_InputGateBias); + ManagedConstTensorHandle managedForgetGateBias(m_QuantizedLstmParameters.m_ForgetGateBias); + ManagedConstTensorHandle managedCellBias(m_QuantizedLstmParameters.m_CellBias); + ManagedConstTensorHandle managedOutputGateBias(m_QuantizedLstmParameters.m_OutputGateBias); + // InputToX weight tensors ConstTensor inputToInputWeightsTensor; if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr) { - ConstTensor inputToInputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_InputToInputWeights->Map(true)); + ConstTensor inputToInputWeightsTensorCopy(managedInputToInputWeights.GetTensorInfo(), + managedInputToInputWeights.Map()); inputToInputWeightsTensor = inputToInputWeightsTensorCopy; inputParams.m_InputToInputWeights = &inputToInputWeightsTensor; } @@ -186,8 +201,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor inputToForgetWeightsTensor; if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr) { - ConstTensor inputToForgetWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_InputToForgetWeights->Map(true)); + ConstTensor inputToForgetWeightsTensorCopy(managedInputToForgetWeights.GetTensorInfo(), + managedInputToForgetWeights.Map()); inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy; inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor; } @@ -195,8 +210,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor inputToCellWeightsTensor; if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr) { - ConstTensor inputToCellWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_InputToCellWeights->Map(true)); + ConstTensor inputToCellWeightsTensorCopy(managedInputToCellWeights.GetTensorInfo(), + managedInputToCellWeights.Map()); inputToCellWeightsTensor = inputToCellWeightsTensorCopy; inputParams.m_InputToCellWeights = &inputToCellWeightsTensor; } @@ -204,8 +219,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor inputToOutputWeightsTensor; if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr) { - ConstTensor inputToOutputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_InputToOutputWeights->Map(true)); + ConstTensor inputToOutputWeightsTensorCopy(managedInputToOutputWeights.GetTensorInfo(), + managedInputToOutputWeights.Map()); inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy; inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor; } @@ -215,8 +230,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr) { ConstTensor recurrentToInputWeightsTensorCopy( - m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_RecurrentToInputWeights->Map(true)); + managedRecurrentToInputWeights.GetTensorInfo(), + managedRecurrentToInputWeights.Map()); recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy; inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; } @@ -225,8 +240,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr) { ConstTensor recurrentToForgetWeightsTensorCopy( - m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_RecurrentToForgetWeights->Map(true)); + managedRecurrentToForgetWeights.GetTensorInfo(), + managedRecurrentToForgetWeights.Map()); recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy; inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; } @@ -235,8 +250,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr) { ConstTensor recurrentToCellWeightsTensorCopy( - m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_RecurrentToCellWeights->Map(true)); + managedRecurrentToCellWeights.GetTensorInfo(), + managedRecurrentToCellWeights.Map()); recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy; inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; } @@ -245,8 +260,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr) { ConstTensor recurrentToOutputWeightsTensorCopy( - m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_RecurrentToOutputWeights->Map(true)); + managedRecurrentToOutputWeights.GetTensorInfo(), + managedRecurrentToOutputWeights.Map()); recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy; inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; } @@ -255,8 +270,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor inputGateBiasTensor; if (m_QuantizedLstmParameters.m_InputGateBias != nullptr) { - ConstTensor inputGateBiasTensorCopy(m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(), - m_QuantizedLstmParameters.m_InputGateBias->Map(true)); + ConstTensor inputGateBiasTensorCopy(managedInputGateBias.GetTensorInfo(), + managedInputGateBias.Map()); inputGateBiasTensor = inputGateBiasTensorCopy; inputParams.m_InputGateBias = &inputGateBiasTensor; } @@ -264,8 +279,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor forgetGateBiasTensor; if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr) { - ConstTensor forgetGateBiasTensorCopy(m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(), - m_QuantizedLstmParameters.m_ForgetGateBias->Map(true)); + ConstTensor forgetGateBiasTensorCopy(managedForgetGateBias.GetTensorInfo(), + managedForgetGateBias.Map()); forgetGateBiasTensor = forgetGateBiasTensorCopy; inputParams.m_ForgetGateBias = &forgetGateBiasTensor; } @@ -273,8 +288,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor cellBiasTensor; if (m_QuantizedLstmParameters.m_CellBias != nullptr) { - ConstTensor cellBiasTensorCopy(m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(), - m_QuantizedLstmParameters.m_CellBias->Map(true)); + ConstTensor cellBiasTensorCopy(managedCellBias.GetTensorInfo(), + managedCellBias.Map()); cellBiasTensor = cellBiasTensorCopy; inputParams.m_CellBias = &cellBiasTensor; } @@ -282,8 +297,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const ConstTensor outputGateBiasTensor; if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr) { - ConstTensor outputGateBiasCopy(m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(), - m_QuantizedLstmParameters.m_OutputGateBias->Map(true)); + ConstTensor outputGateBiasCopy(managedOutputGateBias.GetTensorInfo(), + managedOutputGateBias.Map()); outputGateBiasTensor = outputGateBiasCopy; inputParams.m_OutputGateBias = &outputGateBiasTensor; } @@ -295,83 +310,98 @@ void QuantizedLstmLayer::ExecuteStrategy(IStrategy& strategy) const { std::vector<ConstTensor> constTensors; + ManagedConstTensorHandle managedInputToInputWeights(m_QuantizedLstmParameters.m_InputToInputWeights); + ManagedConstTensorHandle managedInputToForgetWeights(m_QuantizedLstmParameters.m_InputToForgetWeights); + ManagedConstTensorHandle managedInputToCellWeights(m_QuantizedLstmParameters.m_InputToCellWeights); + ManagedConstTensorHandle managedInputToOutputWeights(m_QuantizedLstmParameters.m_InputToOutputWeights); + + ManagedConstTensorHandle managedRecurrentToInputWeights(m_QuantizedLstmParameters.m_RecurrentToInputWeights); + ManagedConstTensorHandle managedRecurrentToForgetWeights(m_QuantizedLstmParameters.m_RecurrentToForgetWeights); + ManagedConstTensorHandle managedRecurrentToCellWeights(m_QuantizedLstmParameters.m_RecurrentToCellWeights); + ManagedConstTensorHandle managedRecurrentToOutputWeights(m_QuantizedLstmParameters.m_RecurrentToOutputWeights); + + ManagedConstTensorHandle managedInputGateBias(m_QuantizedLstmParameters.m_InputGateBias); + ManagedConstTensorHandle managedForgetGateBias(m_QuantizedLstmParameters.m_ForgetGateBias); + ManagedConstTensorHandle managedCellBias(m_QuantizedLstmParameters.m_CellBias); + ManagedConstTensorHandle managedOutputGateBias(m_QuantizedLstmParameters.m_OutputGateBias); + // InputToX weight tensors if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_InputToInputWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(), + managedInputToInputWeights.Map())); } if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_InputToForgetWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(), + managedInputToForgetWeights.Map())); } if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_InputToCellWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(), + managedInputToCellWeights.Map())); } if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr) { - constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_InputToOutputWeights->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(), + managedInputToOutputWeights.Map())); } // RecurrentToX weight tensors if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_RecurrentToInputWeights->Map(true))); + managedRecurrentToInputWeights.GetTensorInfo(), + managedRecurrentToInputWeights.Map())); } if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_RecurrentToForgetWeights->Map(true))); + managedRecurrentToForgetWeights.GetTensorInfo(), + managedRecurrentToForgetWeights.Map())); } if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_RecurrentToCellWeights->Map(true))); + managedRecurrentToCellWeights.GetTensorInfo(), + managedRecurrentToCellWeights.Map())); } if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr) { constTensors.emplace_back(ConstTensor( - m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(), - m_QuantizedLstmParameters.m_RecurrentToOutputWeights->Map(true))); + managedRecurrentToOutputWeights.GetTensorInfo(), + managedRecurrentToOutputWeights.Map())); } // Bias tensors if (m_QuantizedLstmParameters.m_InputGateBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(), - m_QuantizedLstmParameters.m_InputGateBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(), + managedInputGateBias.Map())); } if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(), - m_QuantizedLstmParameters.m_ForgetGateBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(), + managedForgetGateBias.Map())); } if (m_QuantizedLstmParameters.m_CellBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(), - m_QuantizedLstmParameters.m_CellBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(), + managedCellBias.Map())); } if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr) { - constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(), - m_QuantizedLstmParameters.m_OutputGateBias->Map(true))); + constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(), + managedOutputGateBias.Map())); } |