diff options
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index a43800827f..2195c71735 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -1458,6 +1458,44 @@ IConnectableLayer* Network::AddStackLayer(const StackDescriptor& stackDescriptor return m_Graph->AddLayer<StackLayer>(stackDescriptor, name); } +IConnectableLayer* Network::AddQuantizedLstmLayer(const QuantizedLstmInputParams& params, + const char* name) +{ + const auto layer = m_Graph->AddLayer<QuantizedLstmLayer>(name); + + // InputToX weights + layer->m_QuantizedLstmParameters.m_InputToInputWeights = + std::make_unique<ScopedCpuTensorHandle>(params.get_InputToInputWeights()); + layer->m_QuantizedLstmParameters.m_InputToForgetWeights = + std::make_unique<ScopedCpuTensorHandle>(params.get_InputToForgetWeights()); + layer->m_QuantizedLstmParameters.m_InputToCellWeights = + std::make_unique<ScopedCpuTensorHandle>(params.get_InputToCellWeights()); + layer->m_QuantizedLstmParameters.m_InputToOutputWeights = + std::make_unique<ScopedCpuTensorHandle>(params.get_InputToOutputWeights()); + + // RecurrentToX weights + layer->m_QuantizedLstmParameters.m_RecurrentToInputWeights = + std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToInputWeights()); + layer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights = + std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToForgetWeights()); + layer->m_QuantizedLstmParameters.m_RecurrentToCellWeights = + std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToCellWeights()); + layer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights = + std::make_unique<ScopedCpuTensorHandle>(params.get_RecurrentToOutputWeights()); + + // Bias + layer->m_QuantizedLstmParameters.m_InputGateBias = + std::make_unique<ScopedCpuTensorHandle>(params.get_InputGateBias()); + layer->m_QuantizedLstmParameters.m_ForgetGateBias = + std::make_unique<ScopedCpuTensorHandle>(params.get_ForgetGateBias()); + layer->m_QuantizedLstmParameters.m_CellBias = + std::make_unique<ScopedCpuTensorHandle>(params.get_CellBias()); + layer->m_QuantizedLstmParameters.m_OutputGateBias = + std::make_unique<ScopedCpuTensorHandle>(params.get_OutputGateBias()); + + return layer; +} + void Network::Accept(ILayerVisitor& visitor) const { for (auto layer : GetGraph()) |