Diffstat (limited to 'src/armnn')
-rw-r--r--  src/armnn/layers/LstmLayer.cpp | 81
-rw-r--r--  src/armnn/layers/LstmLayer.hpp | 13
2 files changed, 93 insertions(+), 1 deletion(-)
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp
index 2b99f284e8..4012839dfe 100644
--- a/src/armnn/layers/LstmLayer.cpp
+++ b/src/armnn/layers/LstmLayer.cpp
@@ -55,6 +55,19 @@ std::unique_ptr<IWorkload> LstmLayer::CreateWorkload(const Graph& graph, const I
descriptor.m_CellToForgetWeights = m_PeepholeParameters.m_CellToForgetWeights.get();
descriptor.m_CellToOutputWeights = m_PeepholeParameters.m_CellToOutputWeights.get();
}
+
+ // Layer normalisation parameters
+ if(m_Param.m_LayerNormEnabled)
+ {
+ if (!m_Param.m_CifgEnabled)
+ {
+ descriptor.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights.get();
+ }
+ descriptor.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights.get();
+ descriptor.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights.get();
+ descriptor.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights.get();
+ }
+
return factory.CreateLstm(descriptor, PrepInfoAndDesc(descriptor, graph));
}
@@ -110,6 +123,18 @@ LstmLayer* LstmLayer::Clone(Graph& graph) const
std::make_unique<ScopedCpuTensorHandle>(*m_PeepholeParameters.m_CellToOutputWeights) : nullptr;
}
+ if (m_Param.m_LayerNormEnabled)
+ {
+ layer->m_LayerNormParameters.m_InputLayerNormWeights = m_LayerNormParameters.m_InputLayerNormWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_InputLayerNormWeights) : nullptr;
+ layer->m_LayerNormParameters.m_ForgetLayerNormWeights = m_LayerNormParameters.m_ForgetLayerNormWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_ForgetLayerNormWeights) : nullptr;
+ layer->m_LayerNormParameters.m_CellLayerNormWeights = m_LayerNormParameters.m_CellLayerNormWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_CellLayerNormWeights) : nullptr;
+ layer->m_LayerNormParameters.m_OutputLayerNormWeights = m_LayerNormParameters.m_OutputLayerNormWeights ?
+ std::make_unique<ScopedCpuTensorHandle>(*m_LayerNormParameters.m_OutputLayerNormWeights) : nullptr;
+ }
+
return std::move(layer);
}
@@ -220,6 +245,21 @@ void LstmLayer::ValidateTensorShapesFromInputs()
"LstmLayer: TensorShape set on OutputSlot[3] does not match the inferred shape.",
GetOutputSlot(3).GetTensorInfo().GetShape(),
inferredShapes[3]);
+
+ if (m_Param.m_LayerNormEnabled)
+ {
+ if(!m_Param.m_CifgEnabled)
+ {
+ BOOST_ASSERT_MSG(m_LayerNormParameters.m_InputLayerNormWeights != nullptr,
+ "LstmLayer: m_LayerNormParameters.m_inputLayerNormWeights should not be null.");
+ }
+ BOOST_ASSERT_MSG(m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr,
+ "LstmLayer: m_LayerNormParameters.m_forgetLayerNormWeights should not be null.");
+ BOOST_ASSERT_MSG(m_LayerNormParameters.m_CellLayerNormWeights != nullptr,
+ "LstmLayer: m_LayerNormParameters.m_cellLayerNormWeights should not be null.");
+ BOOST_ASSERT_MSG(m_LayerNormParameters.m_OutputLayerNormWeights != nullptr,
+ "LstmLayer: m_LayerNormParameters.m_outputLayerNormWeights should not be null.");
+ }
}
Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef()
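
The asserts in the hunk above encode the rule that layer normalisation needs one 1D weight vector per gate, and that the input-gate vector is only required when CIFG is disabled (CIFG removes the input gate). Below is a small hypothetical caller-side helper, not part of this commit, that sketches the same rule against the public LstmDescriptor/LstmInputParams types; the function name is an illustrative assumption.

#include <armnn/Descriptors.hpp>
#include <armnn/LstmParams.hpp>

// Hypothetical check mirroring the guards in ValidateTensorShapesFromInputs:
// with m_LayerNormEnabled, the forget/cell/output norm vectors are always
// required, while the input norm vector is only required when CIFG is disabled.
bool LayerNormParamsConsistent(const armnn::LstmDescriptor& desc,
                               const armnn::LstmInputParams& params)
{
    if (!desc.m_LayerNormEnabled)
    {
        return true; // nothing to validate
    }
    const bool inputNormOk = desc.m_CifgEnabled ||
                             (params.m_InputLayerNormWeights != nullptr);
    return inputNormOk
        && params.m_ForgetLayerNormWeights != nullptr
        && params.m_CellLayerNormWeights   != nullptr
        && params.m_OutputLayerNormWeights != nullptr;
}
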
@@ -246,7 +286,13 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef()
// Peephole parameters
m_PeepholeParameters.m_CellToForgetWeights,
- m_PeepholeParameters.m_CellToOutputWeights};
+ m_PeepholeParameters.m_CellToOutputWeights,
+
+ // Layer normalisation parameters
+ m_LayerNormParameters.m_InputLayerNormWeights,
+ m_LayerNormParameters.m_ForgetLayerNormWeights,
+ m_LayerNormParameters.m_CellLayerNormWeights,
+ m_LayerNormParameters.m_OutputLayerNormWeights};
}
void LstmLayer::Accept(ILayerVisitor& visitor) const
@@ -392,6 +438,39 @@ void LstmLayer::Accept(ILayerVisitor& visitor) const
projectionBiasTensor = projectionBiasTensorCopy;
inputParams.m_ProjectionBias = &projectionBiasTensor;
}
+ ConstTensor inputLayerNormTensor;
+ if (m_LayerNormParameters.m_InputLayerNormWeights != nullptr)
+ {
+ ConstTensor inputLayerNormTensorCopy(m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(),
+ m_LayerNormParameters.m_InputLayerNormWeights->Map(true));
+ inputLayerNormTensor = inputLayerNormTensorCopy;
+ inputParams.m_InputLayerNormWeights = &inputLayerNormTensor;
+ }
+ ConstTensor forgetLayerNormTensor;
+ if (m_LayerNormParameters.m_ForgetLayerNormWeights != nullptr)
+ {
+ ConstTensor forgetLayerNormTensorCopy(m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(),
+ m_LayerNormParameters.m_ForgetLayerNormWeights->Map(true));
+ forgetLayerNormTensor = forgetLayerNormTensorCopy;
+ inputParams.m_ForgetLayerNormWeights = &forgetLayerNormTensor;
+ }
+ ConstTensor cellLayerNormTensor;
+ if (m_LayerNormParameters.m_CellLayerNormWeights != nullptr)
+ {
+ ConstTensor cellLayerNormTensorCopy(m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(),
+ m_LayerNormParameters.m_CellLayerNormWeights->Map(true));
+ cellLayerNormTensor = cellLayerNormTensorCopy;
+ inputParams.m_CellLayerNormWeights = &cellLayerNormTensor;
+ }
+ ConstTensor outputLayerNormTensor;
+ if (m_LayerNormParameters.m_OutputLayerNormWeights != nullptr)
+ {
+ ConstTensor outputLayerNormTensorCopy(m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(),
+ m_LayerNormParameters.m_OutputLayerNormWeights->Map(true));
+ outputLayerNormTensor = outputLayerNormTensorCopy;
+ inputParams.m_OutputLayerNormWeights = &outputLayerNormTensor;
+ }
+
visitor.VisitLstmLayer(this, GetParameters(), inputParams, GetName());
}
diff --git a/src/armnn/layers/LstmLayer.hpp b/src/armnn/layers/LstmLayer.hpp
index bfea5d8232..584d8e2547 100644
--- a/src/armnn/layers/LstmLayer.hpp
+++ b/src/armnn/layers/LstmLayer.hpp
@@ -11,6 +11,18 @@ namespace armnn
class ScopedCpuTensorHandle;
+struct LstmOptLayerNormParameters
+{
+ /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+ std::unique_ptr<ScopedCpuTensorHandle> m_InputLayerNormWeights;
+ /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+ std::unique_ptr<ScopedCpuTensorHandle> m_ForgetLayerNormWeights;
+ /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+ std::unique_ptr<ScopedCpuTensorHandle> m_CellLayerNormWeights;
+ /// A unique pointer to represent 1D weights tensor with dimensions [num_units].
+ std::unique_ptr<ScopedCpuTensorHandle> m_OutputLayerNormWeights;
+};
+
struct LstmOptCifgParameters
{
/// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units].
@@ -70,6 +82,7 @@ public:
LstmOptCifgParameters m_CifgParameters;
LstmOptProjectionParameters m_ProjectionParameters;
LstmOptPeepholeParameters m_PeepholeParameters;
+ LstmOptLayerNormParameters m_LayerNormParameters;
/// Makes a workload for the LSTM type.
/// @param [in] graph The graph where this layer can be found.