diff options
Diffstat (limited to 'src/armnn/layers/LstmLayer.cpp')
-rw-r--r-- | src/armnn/layers/LstmLayer.cpp | 146 |
1 files changed, 146 insertions, 0 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index 8e396ab70c..ebc408a636 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -480,4 +480,150 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const visitor.VisitLstmLayer(this, GetParameters(), inputParams, GetName()); } +void LstmLayer::ExecuteStrategy(IStrategy& strategy) const +{ + std::vector<ConstTensor> constTensors; + + LstmDescriptor descriptor = GetParameters(); + + // First add mandatory/basic parameters + if (m_BasicParameters.m_InputToForgetWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), + m_BasicParameters.m_InputToForgetWeights->Map(true))); + } + if (m_BasicParameters.m_InputToCellWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), + m_BasicParameters.m_InputToCellWeights->Map(true))); + } + if (m_BasicParameters.m_InputToOutputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), + m_BasicParameters.m_InputToOutputWeights->Map(true))); + } + if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr) + { + constTensors.emplace_back(ConstTensor( + m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), + m_BasicParameters.m_RecurrentToForgetWeights->Map(true))); + } + if (m_BasicParameters.m_RecurrentToCellWeights != nullptr) + { + constTensors.emplace_back(ConstTensor( + m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), + m_BasicParameters.m_RecurrentToCellWeights->Map(true))); + } + if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor( + m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), + m_BasicParameters.m_RecurrentToOutputWeights->Map(true))); + } + if (m_BasicParameters.m_ForgetGateBias != nullptr) + { + constTensors.emplace_back(ConstTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), + m_BasicParameters.m_ForgetGateBias->Map(true))); + } + if (m_BasicParameters.m_CellBias != nullptr) + { + constTensors.emplace_back(ConstTensor(m_BasicParameters.m_CellBias->GetTensorInfo(), + m_BasicParameters.m_CellBias->Map(true))); + } + if (m_BasicParameters.m_OutputGateBias != nullptr) + { + constTensors.emplace_back(ConstTensor(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), + m_BasicParameters.m_OutputGateBias->Map(true))); + } + + // Add cifg parameters + if (!descriptor.m_CifgEnabled) + { + if (m_CifgParameters.m_InputToInputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), + m_CifgParameters.m_InputToInputWeights->Map(true))); + } + if (m_CifgParameters.m_RecurrentToInputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor( + m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), + m_CifgParameters.m_RecurrentToInputWeights->Map(true))); + } + if (m_CifgParameters.m_InputGateBias != nullptr) + { + constTensors.emplace_back(ConstTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(), + m_CifgParameters.m_InputGateBias->Map(true))); + } + } + + // Add peephole parameters + if (descriptor.m_PeepholeEnabled) + { + if (!descriptor.m_CifgEnabled) + { + if (m_PeepholeParameters.m_CellToInputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToInputWeights->Map(true))); + } + } + if (m_PeepholeParameters.m_CellToForgetWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToForgetWeights->Map(true))); + } + if (m_PeepholeParameters.m_CellToOutputWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToOutputWeights->Map(true))); + } + } + + // Add projection parameters + if (descriptor.m_ProjectionEnabled) + { + if (m_ProjectionParameters.m_ProjectionWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), + m_ProjectionParameters.m_ProjectionWeights->Map(true))); + } + if (m_ProjectionParameters.m_ProjectionBias != nullptr) + { + constTensors.emplace_back(ConstTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), + m_ProjectionParameters.m_ProjectionBias->Map(true))); + } + } + + // Add norm parameters + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), + m_LayerNormParameters.m_InputLayerNormWeights->Map(true))); + } + } + if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), + m_LayerNormParameters.m_ForgetLayerNormWeights->Map(true))); + } + if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), + m_LayerNormParameters.m_CellLayerNormWeights->Map(true))); + } + if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr) + { + constTensors.emplace_back(ConstTensor(m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), + m_LayerNormParameters.m_OutputLayerNormWeights->Map(true))); + } + } + + strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName()); +} + } // namespace armnn |