aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/QLstmLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/QLstmLayer.cpp')
-rw-r--r--src/armnn/layers/QLstmLayer.cpp226
1 files changed, 142 insertions, 84 deletions
diff --git a/src/armnn/layers/QLstmLayer.cpp b/src/armnn/layers/QLstmLayer.cpp
index 16aa718eb9..72b020f109 100644
--- a/src/armnn/layers/QLstmLayer.cpp
+++ b/src/armnn/layers/QLstmLayer.cpp
@@ -305,12 +305,41 @@ Layer::ConstantTensors QLstmLayer::GetConstantTensorsByRef()
void QLstmLayer::Accept(ILayerVisitor& visitor) const
{
LstmInputParams inputParams;
+ ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
+ ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
+ ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
+ ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
+ ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
+ ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
+ ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
+ ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
+ ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);
+
+ // Cifg parameters
+ ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
+ ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
+ ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);
+
+ // Projection parameters
+ ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
+ ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);
+
+ // Peephole parameters
+ ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
+ ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
+ ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);
+
+ // Layer normalisation parameters
+ ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
+ ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
+ ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
+ ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);
ConstTensor inputToInputWeightsTensor;
if (m_CifgParameters.m_InputToInputWeights != nullptr)
{
- ConstTensor inputToInputWeightsTensorCopy(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(),
- m_CifgParameters.m_InputToInputWeights->Map(true));
+ ConstTensor inputToInputWeightsTensorCopy(managedInputToInputWeights.GetTensorInfo(),
+ managedInputToInputWeights.Map());
inputToInputWeightsTensor = inputToInputWeightsTensorCopy;
inputParams.m_InputToInputWeights = &inputToInputWeightsTensor;
}
@@ -318,8 +347,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor inputToForgetWeightsTensor;
if (m_BasicParameters.m_InputToForgetWeights != nullptr)
{
- ConstTensor inputToForgetWeightsTensorCopy(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToForgetWeights->Map(true));
+ ConstTensor inputToForgetWeightsTensorCopy(managedInputToForgetWeights.GetTensorInfo(),
+ managedInputToForgetWeights.Map());
inputToForgetWeightsTensor = inputToForgetWeightsTensorCopy;
inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor;
}
@@ -327,8 +356,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor inputToCellWeightsTensor;
if (m_BasicParameters.m_InputToCellWeights != nullptr)
{
- ConstTensor inputToCellWeightsTensorCopy(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToCellWeights->Map(true));
+ ConstTensor inputToCellWeightsTensorCopy(managedInputToCellWeights.GetTensorInfo(),
+ managedInputToCellWeights.Map());
inputToCellWeightsTensor = inputToCellWeightsTensorCopy;
inputParams.m_InputToCellWeights = &inputToCellWeightsTensor;
}
@@ -336,8 +365,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor inputToOutputWeightsTensor;
if (m_BasicParameters.m_InputToOutputWeights != nullptr)
{
- ConstTensor inputToOutputWeightsTensorCopy(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToOutputWeights->Map(true));
+ ConstTensor inputToOutputWeightsTensorCopy(managedInputToOutputWeights.GetTensorInfo(),
+ managedInputToOutputWeights.Map());
inputToOutputWeightsTensor = inputToOutputWeightsTensorCopy;
inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor;
}
@@ -346,8 +375,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
{
ConstTensor recurrentToInputWeightsTensorCopy(
- m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(),
- m_CifgParameters.m_RecurrentToInputWeights->Map(true));
+ managedRecurrentToInputWeights.GetTensorInfo(),
+ managedRecurrentToInputWeights.Map());
recurrentToInputWeightsTensor = recurrentToInputWeightsTensorCopy;
inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
}
@@ -356,8 +385,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
{
ConstTensor recurrentToForgetWeightsTensorCopy(
- m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
- m_BasicParameters.m_RecurrentToForgetWeights->Map(true));
+ managedRecurrentToForgetWeights.GetTensorInfo(),
+ managedRecurrentToForgetWeights.Map());
recurrentToForgetWeightsTensor = recurrentToForgetWeightsTensorCopy;
inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
}
@@ -366,8 +395,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
{
ConstTensor recurrentToCellWeightsTensorCopy(
- m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(),
- m_BasicParameters.m_RecurrentToCellWeights->Map(true));
+ managedRecurrentToCellWeights.GetTensorInfo(),
+ managedRecurrentToCellWeights.Map());
recurrentToCellWeightsTensor = recurrentToCellWeightsTensorCopy;
inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
}
@@ -376,8 +405,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
{
ConstTensor recurrentToOutputWeightsTensorCopy(
- m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
- m_BasicParameters.m_RecurrentToOutputWeights->Map(true));
+ managedRecurrentToOutputWeights.GetTensorInfo(),
+ managedRecurrentToOutputWeights.Map());
recurrentToOutputWeightsTensor = recurrentToOutputWeightsTensorCopy;
inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
}
@@ -385,8 +414,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor cellToInputWeightsTensor;
if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
{
- ConstTensor cellToInputWeightsTensorCopy(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
- m_PeepholeParameters.m_CellToInputWeights->Map(true));
+ ConstTensor cellToInputWeightsTensorCopy(managedCellToInputWeights.GetTensorInfo(),
+ managedCellToInputWeights.Map());
cellToInputWeightsTensor = cellToInputWeightsTensorCopy;
inputParams.m_CellToInputWeights = &cellToInputWeightsTensor;
}
@@ -394,8 +423,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor cellToForgetWeightsTensor;
if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
{
- ConstTensor cellToForgetWeightsTensorCopy(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
- m_PeepholeParameters.m_CellToForgetWeights->Map(true));
+ ConstTensor cellToForgetWeightsTensorCopy(managedCellToForgetWeights.GetTensorInfo(),
+ managedCellToForgetWeights.Map());
cellToForgetWeightsTensor = cellToForgetWeightsTensorCopy;
inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor;
}
@@ -403,8 +432,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor cellToOutputWeightsTensor;
if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
{
- ConstTensor cellToOutputWeightsTensorCopy(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
- m_PeepholeParameters.m_CellToOutputWeights->Map(true));
+ ConstTensor cellToOutputWeightsTensorCopy(managedCellToOutputWeights.GetTensorInfo(),
+ managedCellToOutputWeights.Map());
cellToOutputWeightsTensor = cellToOutputWeightsTensorCopy;
inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor;
}
@@ -412,8 +441,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor inputGateBiasTensor;
if (m_CifgParameters.m_InputGateBias != nullptr)
{
- ConstTensor inputGateBiasTensorCopy(m_CifgParameters.m_InputGateBias->GetTensorInfo(),
- m_CifgParameters.m_InputGateBias->Map(true));
+ ConstTensor inputGateBiasTensorCopy(managedInputGateBias.GetTensorInfo(),
+ managedInputGateBias.Map());
inputGateBiasTensor = inputGateBiasTensorCopy;
inputParams.m_InputGateBias = &inputGateBiasTensor;
}
@@ -421,8 +450,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor forgetGateBiasTensor;
if (m_BasicParameters.m_ForgetGateBias != nullptr)
{
- ConstTensor forgetGateBiasTensorCopy(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(),
- m_BasicParameters.m_ForgetGateBias->Map(true));
+ ConstTensor forgetGateBiasTensorCopy(managedForgetGateBias.GetTensorInfo(),
+ managedForgetGateBias.Map());
forgetGateBiasTensor = forgetGateBiasTensorCopy;
inputParams.m_ForgetGateBias = &forgetGateBiasTensor;
}
@@ -430,8 +459,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor cellBiasTensor;
if (m_BasicParameters.m_CellBias != nullptr)
{
- ConstTensor cellBiasTensorCopy(m_BasicParameters.m_CellBias->GetTensorInfo(),
- m_BasicParameters.m_CellBias->Map(true));
+ ConstTensor cellBiasTensorCopy(managedCellBias.GetTensorInfo(),
+ managedCellBias.Map());
cellBiasTensor = cellBiasTensorCopy;
inputParams.m_CellBias = &cellBiasTensor;
}
@@ -439,8 +468,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor outputGateBias;
if (m_BasicParameters.m_OutputGateBias != nullptr)
{
- ConstTensor outputGateBiasCopy(m_BasicParameters.m_OutputGateBias->GetTensorInfo(),
- m_BasicParameters.m_OutputGateBias->Map(true));
+ ConstTensor outputGateBiasCopy(managedOutputGateBias.GetTensorInfo(),
+ managedOutputGateBias.Map());
outputGateBias = outputGateBiasCopy;
inputParams.m_OutputGateBias = &outputGateBias;
}
@@ -448,8 +477,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor projectionWeightsTensor;
if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
{
- ConstTensor projectionWeightsTensorCopy(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
- m_ProjectionParameters.m_ProjectionWeights->Map(true));
+ ConstTensor projectionWeightsTensorCopy(managedProjectionWeights.GetTensorInfo(),
+ managedProjectionWeights.Map());
projectionWeightsTensor = projectionWeightsTensorCopy;
inputParams.m_ProjectionWeights = &projectionWeightsTensor;
}
@@ -457,8 +486,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor projectionBiasTensor;
if (m_ProjectionParameters.m_ProjectionBias != nullptr)
{
- ConstTensor projectionBiasTensorCopy(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
- m_ProjectionParameters.m_ProjectionBias->Map(true));
+ ConstTensor projectionBiasTensorCopy(managedProjectionBias.GetTensorInfo(),
+ managedProjectionBias.Map());
projectionBiasTensor = projectionBiasTensorCopy;
inputParams.m_ProjectionBias = &projectionBiasTensor;
}
@@ -466,8 +495,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor inputLayerNormTensor;
if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
{
- ConstTensor inputLayerNormTensorCopy(m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(),
- m_LayerNormParameters.m_InputLayerNormWeights->Map(true));
+ ConstTensor inputLayerNormTensorCopy(managedInputLayerNormWeights.GetTensorInfo(),
+ managedInputLayerNormWeights.Map());
inputLayerNormTensor = inputLayerNormTensorCopy;
inputParams.m_InputLayerNormWeights = &inputLayerNormTensor;
}
@@ -475,8 +504,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor forgetLayerNormTensor;
if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
{
- ConstTensor forgetLayerNormTensorCopy(m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(),
- m_LayerNormParameters.m_ForgetLayerNormWeights->Map(true));
+ ConstTensor forgetLayerNormTensorCopy(managedForgetLayerNormWeights.GetTensorInfo(),
+ managedForgetLayerNormWeights.Map());
forgetLayerNormTensor = forgetLayerNormTensorCopy;
inputParams.m_ForgetLayerNormWeights = &forgetLayerNormTensor;
}
@@ -484,8 +513,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor cellLayerNormTensor;
if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
{
- ConstTensor cellLayerNormTensorCopy(m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(),
- m_LayerNormParameters.m_CellLayerNormWeights->Map(true));
+ ConstTensor cellLayerNormTensorCopy(managedCellLayerNormWeights.GetTensorInfo(),
+ managedCellLayerNormWeights.Map());
cellLayerNormTensor = cellLayerNormTensorCopy;
inputParams.m_CellLayerNormWeights = &cellLayerNormTensor;
}
@@ -493,8 +522,8 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
ConstTensor outputLayerNormTensor;
if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
{
- ConstTensor outputLayerNormTensorCopy(m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(),
- m_LayerNormParameters.m_OutputLayerNormWeights->Map(true));
+ ConstTensor outputLayerNormTensorCopy(managedOutputLayerNormWeights.GetTensorInfo(),
+ managedOutputLayerNormWeights.Map());
outputLayerNormTensor = outputLayerNormTensorCopy;
inputParams.m_OutputLayerNormWeights = &outputLayerNormTensor;
}
@@ -507,124 +536,153 @@ void QLstmLayer::Accept(ILayerVisitor& visitor) const
void QLstmLayer::ExecuteStrategy(IStrategy& strategy) const
{
std::vector<ConstTensor> constTensors;
+ ManagedConstTensorHandle managedInputToForgetWeights(m_BasicParameters.m_InputToForgetWeights);
+ ManagedConstTensorHandle managedInputToCellWeights(m_BasicParameters.m_InputToCellWeights);
+ ManagedConstTensorHandle managedInputToOutputWeights(m_BasicParameters.m_InputToOutputWeights);
+ ManagedConstTensorHandle managedRecurrentToForgetWeights(m_BasicParameters.m_RecurrentToForgetWeights);
+ ManagedConstTensorHandle managedRecurrentToCellWeights(m_BasicParameters.m_RecurrentToCellWeights);
+ ManagedConstTensorHandle managedRecurrentToOutputWeights(m_BasicParameters.m_RecurrentToOutputWeights);
+ ManagedConstTensorHandle managedForgetGateBias(m_BasicParameters.m_ForgetGateBias);
+ ManagedConstTensorHandle managedCellBias(m_BasicParameters.m_CellBias);
+ ManagedConstTensorHandle managedOutputGateBias(m_BasicParameters.m_OutputGateBias);
+
+ // Cifg parameters
+ ManagedConstTensorHandle managedInputToInputWeights(m_CifgParameters.m_InputToInputWeights);
+ ManagedConstTensorHandle managedRecurrentToInputWeights(m_CifgParameters.m_RecurrentToInputWeights);
+ ManagedConstTensorHandle managedInputGateBias(m_CifgParameters.m_InputGateBias);
+
+ // Projection parameters
+ ManagedConstTensorHandle managedProjectionWeights(m_ProjectionParameters.m_ProjectionWeights);
+ ManagedConstTensorHandle managedProjectionBias(m_ProjectionParameters.m_ProjectionBias);
+
+ // Peephole parameters
+ ManagedConstTensorHandle managedCellToInputWeights(m_PeepholeParameters.m_CellToInputWeights);
+ ManagedConstTensorHandle managedCellToForgetWeights(m_PeepholeParameters.m_CellToForgetWeights);
+ ManagedConstTensorHandle managedCellToOutputWeights(m_PeepholeParameters.m_CellToOutputWeights);
+
+ // Layer normalisation parameters
+ ManagedConstTensorHandle managedInputLayerNormWeights(m_LayerNormParameters.m_InputLayerNormWeights);
+ ManagedConstTensorHandle managedForgetLayerNormWeights(m_LayerNormParameters.m_ForgetLayerNormWeights);
+ ManagedConstTensorHandle managedCellLayerNormWeights(m_LayerNormParameters.m_CellLayerNormWeights);
+ ManagedConstTensorHandle managedOutputLayerNormWeights(m_LayerNormParameters.m_OutputLayerNormWeights);
// First add mandatory/basic parameters
if (m_BasicParameters.m_InputToForgetWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToForgetWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputToForgetWeights.GetTensorInfo(),
+ managedInputToForgetWeights.Map()));
}
if (m_BasicParameters.m_InputToCellWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToCellWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputToCellWeights.GetTensorInfo(),
+ managedInputToCellWeights.Map()));
}
if (m_BasicParameters.m_InputToOutputWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(),
- m_BasicParameters.m_InputToOutputWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputToOutputWeights.GetTensorInfo(),
+ managedInputToOutputWeights.Map()));
}
if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr)
{
constTensors.emplace_back(ConstTensor(
- m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
- m_BasicParameters.m_RecurrentToForgetWeights->Map(true)));
+ managedRecurrentToForgetWeights.GetTensorInfo(),
+ managedRecurrentToForgetWeights.Map()));
}
if (m_BasicParameters.m_RecurrentToCellWeights != nullptr)
{
constTensors.emplace_back(ConstTensor(
- m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(),
- m_BasicParameters.m_RecurrentToCellWeights->Map(true)));
+ managedRecurrentToCellWeights.GetTensorInfo(),
+ managedRecurrentToCellWeights.Map()));
}
if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr)
{
constTensors.emplace_back(ConstTensor(
- m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
- m_BasicParameters.m_RecurrentToOutputWeights->Map(true)));
+ managedRecurrentToOutputWeights.GetTensorInfo(),
+ managedRecurrentToOutputWeights.Map()));
}
if (m_BasicParameters.m_ForgetGateBias != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(),
- m_BasicParameters.m_ForgetGateBias->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedForgetGateBias.GetTensorInfo(),
+ managedForgetGateBias.Map()));
}
if (m_BasicParameters.m_CellBias != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_BasicParameters.m_CellBias->GetTensorInfo(),
- m_BasicParameters.m_CellBias->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedCellBias.GetTensorInfo(),
+ managedCellBias.Map()));
}
if (m_BasicParameters.m_OutputGateBias != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_BasicParameters.m_OutputGateBias->GetTensorInfo(),
- m_BasicParameters.m_OutputGateBias->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedOutputGateBias.GetTensorInfo(),
+ managedOutputGateBias.Map()));
}
// Add cifig parameters
if (m_CifgParameters.m_InputToInputWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(),
- m_CifgParameters.m_InputToInputWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputToInputWeights.GetTensorInfo(),
+ managedInputToInputWeights.Map()));
}
if (m_CifgParameters.m_RecurrentToInputWeights != nullptr)
{
constTensors.emplace_back(ConstTensor(
- m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(),
- m_CifgParameters.m_RecurrentToInputWeights->Map(true)));
+ managedRecurrentToInputWeights.GetTensorInfo(),
+ managedRecurrentToInputWeights.Map()));
}
if (m_CifgParameters.m_InputGateBias != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(),
- m_CifgParameters.m_InputGateBias->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputGateBias.GetTensorInfo(),
+ managedInputGateBias.Map()));
}
// Add peephole parameters
if (m_PeepholeParameters.m_CellToInputWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
- m_PeepholeParameters.m_CellToInputWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedCellToInputWeights.GetTensorInfo(),
+ managedCellToInputWeights.Map()));
}
if (m_PeepholeParameters.m_CellToForgetWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(),
- m_PeepholeParameters.m_CellToForgetWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedCellToForgetWeights.GetTensorInfo(),
+ managedCellToForgetWeights.Map()));
}
if (m_PeepholeParameters.m_CellToOutputWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(),
- m_PeepholeParameters.m_CellToOutputWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedCellToOutputWeights.GetTensorInfo(),
+ managedCellToOutputWeights.Map()));
}
// Add projection parameters
if (m_ProjectionParameters.m_ProjectionWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(),
- m_ProjectionParameters.m_ProjectionWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedProjectionWeights.GetTensorInfo(),
+ managedProjectionWeights.Map()));
}
if (m_ProjectionParameters.m_ProjectionBias != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(),
- m_ProjectionParameters.m_ProjectionBias->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedProjectionBias.GetTensorInfo(),
+ managedProjectionBias.Map()));
}
// Add norm parameters
if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(),
- m_LayerNormParameters.m_InputLayerNormWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedInputLayerNormWeights.GetTensorInfo(),
+ managedInputLayerNormWeights.Map()));
}
if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(),
- m_LayerNormParameters.m_ForgetLayerNormWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedForgetLayerNormWeights.GetTensorInfo(),
+ managedForgetLayerNormWeights.Map()));
}
if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(),
- m_LayerNormParameters.m_CellLayerNormWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedCellLayerNormWeights.GetTensorInfo(),
+ managedCellLayerNormWeights.Map()));
}
if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
{
- constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(),
- m_LayerNormParameters.m_OutputLayerNormWeights->Map(true)));
+ constTensors.emplace_back(ConstTensor(managedOutputLayerNormWeights.GetTensorInfo(),
+ managedOutputLayerNormWeights.Map()));
}
strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
}