diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-07-15 16:16:25 +0100 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2021-07-22 18:29:55 +0100 |
commit | 8ed39ae450a077c7e4d672b5f05ff1d68ee67aab (patch) | |
tree | 31a1cf006e50db54f3e7a605825c8e9e3f9d689e /src/armnn/Network.cpp | |
parent | 15fcc7ed3163c9d4b1856955271854198c3c2696 (diff) | |
download | armnn-8ed39ae450a077c7e4d672b5f05ff1d68ee67aab.tar.gz |
MLCE-530 Add front end support for UnidirectionalSequenceLstm on ArmNN
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I57bcbdec3eb0155f41af0fe7d6abf9bac2ec86eb
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 152 |
1 files changed, 151 insertions, 1 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index d340f021e2..83eafe7993 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -518,6 +518,14 @@ IConnectableLayer* INetwork::AddLogicalBinaryLayer(const LogicalBinaryDescriptor return pNetworkImpl->AddLogicalBinaryLayer(descriptor, name); } +IConnectableLayer* INetwork::AddUnidirectionalSequenceLstmLayer( + const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParams& params, + const char* name) +{ + return pNetworkImpl->AddUnidirectionalSequenceLstmLayer(descriptor, params, name); +} + void INetwork::Accept(ILayerVisitor& visitor) const { return pNetworkImpl->Accept(visitor); @@ -2603,11 +2611,153 @@ IConnectableLayer* NetworkImpl::AddQLstmLayer(const QLstmDescriptor& descriptor } IConnectableLayer* NetworkImpl::AddLogicalBinaryLayer(const LogicalBinaryDescriptor& logicalBinaryDescriptor, - const char* name) + const char* name) { return m_Graph->AddLayer<LogicalBinaryLayer>(logicalBinaryDescriptor, name); } +IConnectableLayer* NetworkImpl::AddUnidirectionalSequenceLstmLayer( + const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParams& params, + const char* name) +{ + const auto layer = m_Graph->AddLayer<UnidirectionalSequenceLstmLayer>(descriptor, name); + + //Lstm Basic Parameters + layer->m_BasicParameters.m_InputToForgetWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_InputToForgetWeights)); + layer->m_BasicParameters.m_InputToCellWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_InputToCellWeights)); + layer->m_BasicParameters.m_InputToOutputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_InputToOutputWeights)); + layer->m_BasicParameters.m_RecurrentToForgetWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToForgetWeights)); + layer->m_BasicParameters.m_RecurrentToCellWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToCellWeights)); + layer->m_BasicParameters.m_RecurrentToOutputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToOutputWeights)); + layer->m_BasicParameters.m_ForgetGateBias = + std::make_shared<ScopedTensorHandle>(*(params.m_ForgetGateBias)); + layer->m_BasicParameters.m_CellBias = + std::make_shared<ScopedTensorHandle>(*(params.m_CellBias)); + layer->m_BasicParameters.m_OutputGateBias = + std::make_shared<ScopedTensorHandle>(*(params.m_OutputGateBias)); + + //Lstm Cifg parameters + if(!descriptor.m_CifgEnabled) + { + if(params.m_InputToInputWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input To Input Weights cannot be NULL " + "when CIFG is disabled."); + } + if(params.m_RecurrentToInputWeights == nullptr) + { + throw InvalidArgumentException( + "AddUnidirectionalSequenceLstmLayer: Recurrent To Input Weights cannot be NULL " + "when CIFG is disabled."); + } + if(params.m_InputGateBias == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input Gate Bias cannot be NULL " + "when CIFG is disabled."); + } + layer->m_CifgParameters.m_InputToInputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_InputToInputWeights)); + layer->m_CifgParameters.m_RecurrentToInputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_RecurrentToInputWeights)); + layer->m_CifgParameters.m_InputGateBias = + std::make_shared<ScopedTensorHandle>(*(params.m_InputGateBias)); + } + + //Lstm projection parameters + if(descriptor.m_ProjectionEnabled) + { + if(params.m_ProjectionWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Projection Weights cannot be NULL " + "when projection is enabled."); + } + layer->m_ProjectionParameters.m_ProjectionWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_ProjectionWeights)); + if(params.m_ProjectionBias != nullptr) + { + layer->m_ProjectionParameters.m_ProjectionBias = + std::make_shared<ScopedTensorHandle>(*(params.m_ProjectionBias)); + } + } + + //Lstm Peephole params + if(descriptor.m_PeepholeEnabled) + { + if(!descriptor.m_CifgEnabled) + { + if(params.m_CellToInputWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Input Weights " + "cannot be NULL when Peephole is enabled and CIFG disabled."); + } + + layer->m_PeepholeParameters.m_CellToInputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_CellToInputWeights)); + } + + if(params.m_CellToForgetWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Forget Weights cannot be NULL " + "when Peephole is enabled."); + } + if(params.m_CellToOutputWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell To Output Weights cannot be NULL " + "when Peephole is enabled."); + } + + layer->m_PeepholeParameters.m_CellToForgetWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_CellToForgetWeights)); + layer->m_PeepholeParameters.m_CellToOutputWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_CellToOutputWeights)); + } + + //Lstm Layer Normalization params + if(descriptor.m_LayerNormEnabled) + { + if(!descriptor.m_CifgEnabled) + { + if(params.m_InputLayerNormWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Input layer normalization weights " + "cannot be NULL when layer normalization is enabled and CIFG disabled."); + } + layer->m_LayerNormParameters.m_InputLayerNormWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_InputLayerNormWeights)); + } + + if(params.m_ForgetLayerNormWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Forget layer normalization weights " + "cannot be NULL when layer normalization is enabled."); + } + if(params.m_CellLayerNormWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Cell layer normalization weights " + "cannot be NULL when layer normalization is enabled."); + } + if(params.m_OutputLayerNormWeights == nullptr) + { + throw InvalidArgumentException("AddUnidirectionalSequenceLstmLayer: Output layer normalization weights " + "cannot be NULL when layer normalization is enabled."); + } + layer->m_LayerNormParameters.m_ForgetLayerNormWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_ForgetLayerNormWeights)); + layer->m_LayerNormParameters.m_CellLayerNormWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_CellLayerNormWeights)); + layer->m_LayerNormParameters.m_OutputLayerNormWeights = + std::make_shared<ScopedTensorHandle>(*(params.m_OutputLayerNormWeights)); + } + return layer; +} + void NetworkImpl::Accept(ILayerVisitor& visitor) const { for (auto layer : GetGraph()) |