aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/LstmLayer.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-06-26 13:10:09 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-02 09:59:37 +0000
commit38e05bd2836b1b65b440330a9c283038ba4192c3 (patch)
treec232f71ce6a101c70ed65e046678f7b22593dbe4 /src/armnn/layers/LstmLayer.cpp
parentd0c0cc3e27f1ada9df167d3b9ff248be432d16e1 (diff)
downloadarmnn-38e05bd2836b1b65b440330a9c283038ba4192c3.tar.gz
IVGCVSW-3236 Extend Ref LSTM with layer normalization support
* Add descriptor values * Update lstm queue descriptor validate function * Update lstm workload * Update isLstmSupported (Cl and Ref), LayerSupportBase, ILayerSupport * Update lstm layer * Add unit tests Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I932175d550facfb342325051eaa7bd2084ebdc18 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/armnn/layers/LstmLayer.cpp')
-rw-r--r--src/armnn/layers/LstmLayer.cpp81
1 files changed, 80 insertions, 1 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp
index 2b99f284e8..4012839dfe 100644
--- a/src/armnn/layers/LstmLayer.cpp
+++ b/src/armnn/layers/LstmLayer.cpp
@@ -55,6 +55,19 @@ std::unique_ptr<IWorkload> LstmLayer::CreateWorkload(const Graph& graph, const I
descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
}
+
+ // Layer normalisation parameters
+ if(m_Param.m_LayerNormEnabled)
+ {
+ if (!m_Param.m_CifgEnabled)
+ {
+ descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
+ }
+ descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
+ descriptor.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights.get();
+ descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
+ }
+
return factory.CreateLstm(descriptor, PrepInfoAndDesc(descriptor, graph));
}
@@ -110,6 +123,18 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const
std::make_unique<ScopedCpuTensorHandle>(*m_PeepholeParameters.m_CellToOutputWeights) : nullptr;
}
+ if (m_Param.m_LayerNormEnabled)
+ {
+ layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_InputLayerNormWeights) : nullptr;
+ layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_ForgetLayerNormWeights) : nullptr;
+ layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_CellLayerNormWeights) : nullptr;
+ layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_OutputLayerNormWeights) : nullptr;
+ }
+
return std::move(layer);
}
@@ -220,6 +245,21 @@ void LstmLayer::ValidateTensorShapesFromInputs()
"LstmLayer: TensorShape set on OutputSlot[3] does not match the inferred shape.",
GetOutputSlot(3).GetTensorInfo().GetShape(),
inferredShapes[3]);
+
+ if (m_Param.m_LayerNormEnabled)
+ {
+ if(!m_Param.m_CifgEnabled)
+ {
+ BOOST_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
+ "LstmLayer: m_LayerNormParameters.m_inputLayerNormWeights should not be null.");
+ }
+ BOOST_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
+ "LstmLayer: m_LayerNormParameters.m_forgetLayerNormWeights should not be null.");
+ BOOST_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
+ "LstmLayer: m_LayerNormParameters.m_cellLayerNormWeights should not be null.");
+ BOOST_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
+ "LstmLayer: m_LayerNormParameters.m_outputLayerNormWeights should not be null.");
+ }
}
Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef()
@@ -246,7 +286,13 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef()
// Peephole parameters
m_PeepholeParameters.m_CellToForgetWeights,
- m_PeepholeParameters.m_CellToOutputWeights};
+ m_PeepholeParameters.m_CellToOutputWeights,
+
+ // Layer normalisation parameters
+ m_LayerNormParameters.m_InputLayerNormWeights,
+ m_LayerNormParameters.m_ForgetLayerNormWeights,
+ m_LayerNormParameters.m_CellLayerNormWeights,
+ m_LayerNormParameters.m_OutputLayerNormWeights};
}
void LstmLayer::Accept(ILayerVisitor& visitor) const
@@ -392,6 +438,39 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const
projectionBiasTensor = projectionBiasTensorCopy;
inputParams.m_ProjectionBias = &projectionBiasTensor;
}
+ ConstTensor inputLayerNormTensor;
+ if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
+ {
+ ConstTensor inputLayerNormTensorCopy(m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(),
+ m_LayerNormParameters.m_InputLayerNormWeights->Map(true));
+ inputLayerNormTensor = inputLayerNormTensorCopy;
+ inputParams.m_InputLayerNormWeights = &inputLayerNormTensor;
+ }
+ ConstTensor forgetLayerNormTensor;
+ if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
+ {
+ ConstTensor forgetLayerNormTensorCopy(m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(),
+ m_LayerNormParameters.m_ForgetLayerNormWeights->Map(true));
+ forgetLayerNormTensor = forgetLayerNormTensorCopy;
+ inputParams.m_ForgetLayerNormWeights = &forgetLayerNormTensor;
+ }
+ ConstTensor cellLayerNormTensor;
+ if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
+ {
+ ConstTensor cellLayerNormTensorCopy(m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(),
+ m_LayerNormParameters.m_CellLayerNormWeights->Map(true));
+ cellLayerNormTensor = cellLayerNormTensorCopy;
+ inputParams.m_CellLayerNormWeights = &cellLayerNormTensor;
+ }
+ ConstTensor outputLayerNormTensor;
+ if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
+ {
+ ConstTensor outputLayerNormTensorCopy(m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(),
+ m_LayerNormParameters.m_OutputLayerNormWeights->Map(true));
+ outputLayerNormTensor = outputLayerNormTensorCopy;
+ inputParams.m_OutputLayerNormWeights = &outputLayerNormTensor;
+ }
+
visitor.VisitLstmLayer(this, GetParameters(), inputParams, GetName());
}