about | summary | refs | log | tree | commit | diff
path: root/src/armnn/layers/QuantizedLstmLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/QuantizedLstmLayer.cpp')
-rw-r--r--  src/armnn/layers/QuantizedLstmLayer.cpp | 126
1 file changed, 78 insertions(+), 48 deletions(-)
diff --git a/src/armnn/layers/QuantizedLstmLayer.cpp b/src/armnn/layers/QuantizedLstmLayer.cpp
index a1ff985abe..4d0dab9505 100644
--- a/src/armnn/layers/QuantizedLstmLayer.cpp
+++ b/src/armnn/layers/QuantizedLstmLayer.cpp
@@ -173,12 +173,27 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
{
QuantizedLstmInputParams inputParams;
+ ManagedConstTensorHandle managedInputToInputWeights(m_QuantizedLstmParameters.m_InputToInputWeights);
+ ManagedConstTensorHandle managedInputToForgetWeights(m_QuantizedLstmParameters.m_InputToForgetWeights);
+ ManagedConstTensorHandle managedInputToCellWeights(m_QuantizedLstmParameters.m_InputToCellWeights);
+ ManagedConstTensorHandle managedInputToOutputWeights(m_QuantizedLstmParameters.m_InputToOutputWeights);
+
+ ManagedConstTensorHandle managedRecurrentToInputWeights(m_QuantizedLstmParameters.m_RecurrentToInputWeights);
+ ManagedConstTensorHandle managedRecurrentToForgetWeights(m_QuantizedLstmParameters.m_RecurrentToForgetWeights);
+ ManagedConstTensorHandle managedRecurrentToCellWeights(m_QuantizedLstmParameters.m_RecurrentToCellWeights);
+ ManagedConstTensorHandle managedRecurrentToOutputWeights(m_QuantizedLstmParameters.m_RecurrentToOutputWeights);
+
+ ManagedConstTensorHandle managedInputGateBias(m_QuantizedLstmParameters.m_InputGateBias);
+ ManagedConstTensorHandle managedForgetGateBias(m_QuantizedLstmParameters.m_ForgetGateBias);
+ ManagedConstTensorHandle managedCellBias(m_QuantizedLstmParameters.m_CellBias);
+ ManagedConstTensorHandle managedOutputGateBias(m_QuantizedLstmParameters.m_OutputGateBias);
+
// InputToX weight tensors
ConstTensor inputToInputWeightsTensor;
if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr)
{
- ConstTensor inputToInputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_InputToInputWeights->Map(true));
+ ConstTensor inputToInputWeightsTensorCopy(managedInputToInputWeights.GetTensorInfo(),
+ managedInputToInputWeights.Map());
inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
}
@@ -186,8 +201,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor inputToForgetWeightsTensor;
if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr)
{
- ConstTensor inputToForgetWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_InputToForgetWeights->Map(true));
+ ConstTensor inputToForgetWeightsTensorCopy(managedInputToForgetWeights.GetTensorInfo(),
+ managedInputToForgetWeights.Map());
inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
}
@@ -195,8 +210,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor inputToCellWeightsTensor;
if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr)
{
- ConstTensor inputToCellWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_InputToCellWeights->Map(true));
+ ConstTensor inputToCellWeightsTensorCopy(managedInputToCellWeights.GetTensorInfo(),
+ managedInputToCellWeights.Map());
inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
}
@@ -204,8 +219,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor inputToOutputWeightsTensor;
if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr)
{
- ConstTensor inputToOutputWeightsTensorCopy(m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_InputToOutputWeights->Map(true));
+ ConstTensor inputToOutputWeightsTensorCopy(managedInputToOutputWeights.GetTensorInfo(),
+ managedInputToOutputWeights.Map());
inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
}
@@ -215,8 +230,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr)
{
ConstTensor recurrentToInputWeightsTensorCopy(
- m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_RecurrentToInputWeights->Map(true));
+ managedRecurrentToInputWeights.GetTensorInfo(),
+ managedRecurrentToInputWeights.Map());
recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
}
@@ -225,8 +240,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr)
{
ConstTensor recurrentToForgetWeightsTensorCopy(
- m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_RecurrentToForgetWeights->Map(true));
+ managedRecurrentToForgetWeights.GetTensorInfo(),
+ managedRecurrentToForgetWeights.Map());
recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
}
@@ -235,8 +250,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr)
{
ConstTensor recurrentToCellWeightsTensorCopy(
- m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_RecurrentToCellWeights->Map(true));
+ managedRecurrentToCellWeights.GetTensorInfo(),
+ managedRecurrentToCellWeights.Map());
recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
}
@@ -245,8 +260,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr)
{
ConstTensor recurrentToOutputWeightsTensorCopy(
- m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_RecurrentToOutputWeights->Map(true));
+ managedRecurrentToOutputWeights.GetTensorInfo(),
+ managedRecurrentToOutputWeights.Map());
recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
}
@@ -255,8 +270,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor inputGateBiasTensor;
if (m_QuantizedLstmParameters.m_InputGateBias != nullptr)
{
- ConstTensor inputGateBiasTensorCopy(m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(),
- m_QuantizedLstmParameters.m_InputGateBias->Map(true));
+ ConstTensor inputGateBiasTensorCopy(managedInputGateBias.GetTensorInfo(),
+ managedInputGateBias.Map());
inputGateBiasTensor = inputGateBiasTensorCopy;
inputParams.m_InputGateBias = &inputGateBiasTensor;
}
@@ -264,8 +279,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor forgetGateBiasTensor;
if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr)
{
- ConstTensor forgetGateBiasTensorCopy(m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(),
- m_QuantizedLstmParameters.m_ForgetGateBias->Map(true));
+ ConstTensor forgetGateBiasTensorCopy(managedForgetGateBias.GetTensorInfo(),
+ managedForgetGateBias.Map());
forgetGateBiasTensor = forgetGateBiasTensorCopy;
inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
}
@@ -273,8 +288,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor cellBiasTensor;
if (m_QuantizedLstmParameters.m_CellBias != nullptr)
{
- ConstTensor cellBiasTensorCopy(m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(),
- m_QuantizedLstmParameters.m_CellBias->Map(true));
+ ConstTensor cellBiasTensorCopy(managedCellBias.GetTensorInfo(),
+ managedCellBias.Map());
cellBiasTensor = cellBiasTensorCopy;
inputParams.m_CellBias = &cellBiasTensor;
}
@@ -282,8 +297,8 @@ void QuantizedLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor outputGateBiasTensor;
if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr)
{
- ConstTensor outputGateBiasCopy(m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(),
- m_QuantizedLstmParameters.m_OutputGateBias->Map(true));
+ ConstTensor outputGateBiasCopy(managedOutputGateBias.GetTensorInfo(),
+ managedOutputGateBias.Map());
outputGateBiasTensor = outputGateBiasCopy;
inputParams.m_OutputGateBias = &outputGateBiasTensor;
}
@@ -295,83 +310,98 @@ void QuantizedLstmLayer::ExecuteStrategy(IStrategy& strategy) const
{
std::vector<ConstTensor> constTensors;
+ ManagedConstTensorHandle managedInputToInputWeights(m_QuantizedLstmParameters.m_InputToInputWeights);
+ ManagedConstTensorHandle managedInputToForgetWeights(m_QuantizedLstmParameters.m_InputToForgetWeights);
+ ManagedConstTensorHandle managedInputToCellWeights(m_QuantizedLstmParameters.m_InputToCellWeights);
+ ManagedConstTensorHandle managedInputToOutputWeights(m_QuantizedLstmParameters.m_InputToOutputWeights);
+
+ ManagedConstTensorHandle managedRecurrentToInputWeights(m_QuantizedLstmParameters.m_RecurrentToInputWeights);
+ ManagedConstTensorHandle managedRecurrentToForgetWeights(m_QuantizedLstmParameters.m_RecurrentToForgetWeights);
+ ManagedConstTensorHandle managedRecurrentToCellWeights(m_QuantizedLstmParameters.m_RecurrentToCellWeights);
+ ManagedConstTensorHandle managedRecurrentToOutputWeights(m_QuantizedLstmParameters.m_RecurrentToOutputWeights);
+
+ ManagedConstTensorHandle managedInputGateBias(m_QuantizedLstmParameters.m_InputGateBias);
+ ManagedConstTensorHandle managedForgetGateBias(m_QuantizedLstmParameters.m_ForgetGateBias);
+ ManagedConstTensorHandle managedCellBias(m_QuantizedLstmParameters.m_CellBias);
+ ManagedConstTensorHandle managedOutputGateBias(m_QuantizedLstmParameters.m_OutputGateBias);
+
// InputToX weight tensors
if (m_QuantizedLstmParameters.m_InputToInputWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_InputToInputWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
+ managedInputToInputWeights.Map()));
}
if (m_QuantizedLstmParameters.m_InputToForgetWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_InputToForgetWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
+ managedInputToForgetWeights.Map()));
}
if (m_QuantizedLstmParameters.m_InputToCellWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_InputToCellWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
+ managedInputToCellWeights.Map()));
}
if (m_QuantizedLstmParameters.m_InputToOutputWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_InputToOutputWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
+ managedInputToOutputWeights.Map()));
}
// RecurrentToX weight tensors
if (m_QuantizedLstmParameters.m_RecurrentToInputWeights != nullptr)
{
constTensors.emplace_back(ConstTensor(
- m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_RecurrentToInputWeights->Map(true)));
+ managedRecurrentToInputWeights.GetTensorInfo(),
+ managedRecurrentToInputWeights.Map()));
}
if (m_QuantizedLstmParameters.m_RecurrentToForgetWeights != nullptr)
{
constTensors.emplace_back(ConstTensor(
- m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_RecurrentToForgetWeights->Map(true)));
+ managedRecurrentToForgetWeights.GetTensorInfo(),
+ managedRecurrentToForgetWeights.Map()));
}
if (m_QuantizedLstmParameters.m_RecurrentToCellWeights != nullptr)
{
constTensors.emplace_back(ConstTensor(
- m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_RecurrentToCellWeights->Map(true)));
+ managedRecurrentToCellWeights.GetTensorInfo(),
+ managedRecurrentToCellWeights.Map()));
}
if (m_QuantizedLstmParameters.m_RecurrentToOutputWeights != nullptr)
{
constTensors.emplace_back(ConstTensor(
- m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
- m_QuantizedLstmParameters.m_RecurrentToOutputWeights->Map(true)));
+ managedRecurrentToOutputWeights.GetTensorInfo(),
+ managedRecurrentToOutputWeights.Map()));
}
// Bias tensors
if (m_QuantizedLstmParameters.m_InputGateBias != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo(),
- m_QuantizedLstmParameters.m_InputGateBias->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
+ managedInputGateBias.Map()));
}
if (m_QuantizedLstmParameters.m_ForgetGateBias != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo(),
- m_QuantizedLstmParameters.m_ForgetGateBias->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
+ managedForgetGateBias.Map()));
}
if (m_QuantizedLstmParameters.m_CellBias != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_CellBias->GetTensorInfo(),
- m_QuantizedLstmParameters.m_CellBias->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
+ managedCellBias.Map()));
}
if (m_QuantizedLstmParameters.m_OutputGateBias != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo(),
- m_QuantizedLstmParameters.m_OutputGateBias->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
+ managedOutputGateBias.Map()));
}