aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Network.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-17 11:07:49 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-23 11:01:29 +0100
commitf8c629701e760e2476582c91e6b7f5e1313dc02a (patch)
tree3c9dc1a673fe99367a7f25d98199d8ab386e8bc2 /src/armnn/Network.cpp
parent07f2121feeeeae36a7e67eeb8a6965df63b848f3 (diff)
downloadarmnn-f8c629701e760e2476582c91e6b7f5e1313dc02a.tar.gz
IVGCVSW-3526 Add layer norm support for lstm serialization
* Adds layer norm support for serialization/deserialization * Adds related unit tests Change-Id: If80b668accc8b0754a93d18ab3a243284cb383d1 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r--src/armnn/Network.cpp33
1 files changed, 33 insertions, 0 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 29493816a8..6707cc7a26 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1326,6 +1326,39 @@ IConnectableLayer* Network::AddLstmLayer(const LstmDescriptor& descriptor,
layer->m_PeepholeParameters.m_CellToOutputWeights =
std::make_unique<ScopedCpuTensorHandle>(*(params.m_CellToOutputWeights));
}
+
+ //Lstm Layer Normalization params
+ if(descriptor.m_LayerNormEnabled)
+ {
+ if(!descriptor.m_CifgEnabled)
+ {
+ if(params.m_InputLayerNormWeights == nullptr)
+ {
+ throw InvalidArgumentException("AddLstmLayer: Input layer normalization weights cannot be NULL");
+ }
+ layer->m_LayerNormParameters.m_InputLayerNormWeights =
+ std::make_unique<ScopedCpuTensorHandle>(*(params.m_InputLayerNormWeights));
+ }
+
+ if(params.m_ForgetLayerNormWeights == nullptr)
+ {
+ throw InvalidArgumentException("AddLstmLayer: Forget layer normalization weights cannot be NULL");
+ }
+ if(params.m_CellLayerNormWeights == nullptr)
+ {
+ throw InvalidArgumentException("AddLstmLayer: Cell layer normalization weights cannot be NULL");
+ }
+ if(params.m_OutputLayerNormWeights == nullptr)
+ {
+ throw InvalidArgumentException("AddLstmLayer: Output layer normalization weights cannot be NULL");
+ }
+ layer->m_LayerNormParameters.m_ForgetLayerNormWeights =
+ std::make_unique<ScopedCpuTensorHandle>(*(params.m_ForgetLayerNormWeights));
+ layer->m_LayerNormParameters.m_CellLayerNormWeights =
+ std::make_unique<ScopedCpuTensorHandle>(*(params.m_CellLayerNormWeights));
+ layer->m_LayerNormParameters.m_OutputLayerNormWeights =
+ std::make_unique<ScopedCpuTensorHandle>(*(params.m_OutputLayerNormWeights));
+ }
return layer;
}