diff options
Diffstat (limited to 'src/armnn/layers/LstmLayer.hpp')
-rw-r--r-- | src/armnn/layers/LstmLayer.hpp | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/src/armnn/layers/LstmLayer.hpp b/src/armnn/layers/LstmLayer.hpp index bfea5d8232..584d8e2547 100644 --- a/src/armnn/layers/LstmLayer.hpp +++ b/src/armnn/layers/LstmLayer.hpp @@ -11,6 +11,18 @@ namespace armnn class ScopedCpuTensorHandle; +struct LstmOptLayerNormParameters +{ + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::unique_ptr<ScopedCpuTensorHandle> m_InputLayerNormWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::unique_ptr<ScopedCpuTensorHandle> m_ForgetLayerNormWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::unique_ptr<ScopedCpuTensorHandle> m_CellLayerNormWeights; + /// A unique pointer to represent 1D weights tensor with dimensions [num_units]. + std::unique_ptr<ScopedCpuTensorHandle> m_OutputLayerNormWeights; +}; + struct LstmOptCifgParameters { /// A unique pointer to represent 2D weights tensor with dimensions [input_size, num_units]. @@ -70,6 +82,7 @@ public: LstmOptCifgParameters m_CifgParameters; LstmOptProjectionParameters m_ProjectionParameters; LstmOptPeepholeParameters m_PeepholeParameters; + LstmOptLayerNormParameters m_LayerNormParameters; /// Makes a workload for the LSTM type. /// @param [in] graph The graph where this layer can be found. |