ArmNN
 20.02
LstmParams.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "TensorFwd.hpp"
8 #include "Exceptions.hpp"
9 
10 namespace armnn
11 {
12 
14 {
16  : m_InputToInputWeights(nullptr)
17  , m_InputToForgetWeights(nullptr)
18  , m_InputToCellWeights(nullptr)
19  , m_InputToOutputWeights(nullptr)
20  , m_RecurrentToInputWeights(nullptr)
22  , m_RecurrentToCellWeights(nullptr)
24  , m_CellToInputWeights(nullptr)
25  , m_CellToForgetWeights(nullptr)
26  , m_CellToOutputWeights(nullptr)
27  , m_InputGateBias(nullptr)
28  , m_ForgetGateBias(nullptr)
29  , m_CellBias(nullptr)
30  , m_OutputGateBias(nullptr)
31  , m_ProjectionWeights(nullptr)
32  , m_ProjectionBias(nullptr)
33  , m_InputLayerNormWeights(nullptr)
34  , m_ForgetLayerNormWeights(nullptr)
35  , m_CellLayerNormWeights(nullptr)
36  , m_OutputLayerNormWeights(nullptr)
37  {
38  }
39 
61 };
62 
64 {
66  : m_InputToInputWeights(nullptr)
67  , m_InputToForgetWeights(nullptr)
68  , m_InputToCellWeights(nullptr)
69  , m_InputToOutputWeights(nullptr)
70  , m_RecurrentToInputWeights(nullptr)
72  , m_RecurrentToCellWeights(nullptr)
74  , m_CellToInputWeights(nullptr)
75  , m_CellToForgetWeights(nullptr)
76  , m_CellToOutputWeights(nullptr)
77  , m_InputGateBias(nullptr)
78  , m_ForgetGateBias(nullptr)
79  , m_CellBias(nullptr)
80  , m_OutputGateBias(nullptr)
81  , m_ProjectionWeights(nullptr)
82  , m_ProjectionBias(nullptr)
83  , m_InputLayerNormWeights(nullptr)
84  , m_ForgetLayerNormWeights(nullptr)
85  , m_CellLayerNormWeights(nullptr)
86  , m_OutputLayerNormWeights(nullptr)
87  {
88  }
110 
111  const TensorInfo& Deref(const TensorInfo* tensorInfo) const
112  {
113  if (tensorInfo != nullptr)
114  {
115  const TensorInfo &temp = *tensorInfo;
116  return temp;
117  }
118  throw InvalidArgumentException("Can't dereference a null pointer");
119  }
120 
122  {
123  return Deref(m_InputToInputWeights);
124  }
126  {
127  return Deref(m_InputToForgetWeights);
128  }
130  {
131  return Deref(m_InputToCellWeights);
132  }
134  {
135  return Deref(m_InputToOutputWeights);
136  }
138  {
139  return Deref(m_RecurrentToInputWeights);
140  }
142  {
143  return Deref(m_RecurrentToForgetWeights);
144  }
146  {
147  return Deref(m_RecurrentToCellWeights);
148  }
150  {
151  return Deref(m_RecurrentToOutputWeights);
152  }
154  {
155  return Deref(m_CellToInputWeights);
156  }
158  {
159  return Deref(m_CellToForgetWeights);
160  }
162  {
163  return Deref(m_CellToOutputWeights);
164  }
166  {
167  return Deref(m_InputGateBias);
168  }
170  {
171  return Deref(m_ForgetGateBias);
172  }
173  const TensorInfo& GetCellBias() const
174  {
175  return Deref(m_CellBias);
176  }
178  {
179  return Deref(m_OutputGateBias);
180  }
182  {
183  return Deref(m_ProjectionWeights);
184  }
186  {
187  return Deref(m_ProjectionBias);
188  }
190  {
191  return Deref(m_InputLayerNormWeights);
192  }
194  {
195  return Deref(m_ForgetLayerNormWeights);
196  }
198  {
199  return Deref(m_CellLayerNormWeights);
200  }
202  {
203  return Deref(m_OutputLayerNormWeights);
204  }
205 };
206 
207 } // namespace armnn
208 
const TensorInfo * m_InputLayerNormWeights
Definition: LstmParams.hpp:106
const ConstTensor * m_ProjectionWeights
Definition: LstmParams.hpp:55
const TensorInfo * m_OutputGateBias
Definition: LstmParams.hpp:103
const TensorInfo & GetRecurrentToCellWeights() const
Definition: LstmParams.hpp:145
const ConstTensor * m_CellBias
Definition: LstmParams.hpp:53
const TensorInfo * m_ProjectionWeights
Definition: LstmParams.hpp:104
const TensorInfo & GetCellBias() const
Definition: LstmParams.hpp:173
const TensorInfo & GetRecurrentToInputWeights() const
Definition: LstmParams.hpp:137
const TensorInfo & GetCellLayerNormWeights() const
Definition: LstmParams.hpp:197
const ConstTensor * m_CellToOutputWeights
Definition: LstmParams.hpp:50
const TensorInfo & GetRecurrentToOutputWeights() const
Definition: LstmParams.hpp:149
const ConstTensor * m_CellToInputWeights
Definition: LstmParams.hpp:48
const TensorInfo & GetCellToInputWeights() const
Definition: LstmParams.hpp:153
const TensorInfo * m_ForgetLayerNormWeights
Definition: LstmParams.hpp:107
const ConstTensor * m_InputGateBias
Definition: LstmParams.hpp:51
const ConstTensor * m_RecurrentToCellWeights
Definition: LstmParams.hpp:46
const ConstTensor * m_ForgetLayerNormWeights
Definition: LstmParams.hpp:58
const ConstTensor * m_CellToForgetWeights
Definition: LstmParams.hpp:49
Copyright (c) 2020 ARM Limited.
const TensorInfo & GetCellToForgetWeights() const
Definition: LstmParams.hpp:157
const TensorInfo * m_ForgetGateBias
Definition: LstmParams.hpp:101
const TensorInfo * m_OutputLayerNormWeights
Definition: LstmParams.hpp:109
const TensorInfo & GetForgetLayerNormWeights() const
Definition: LstmParams.hpp:193
const TensorInfo & Deref(const TensorInfo *tensorInfo) const
Definition: LstmParams.hpp:111
const TensorInfo * m_RecurrentToCellWeights
Definition: LstmParams.hpp:95
const ConstTensor * m_OutputGateBias
Definition: LstmParams.hpp:54
const TensorInfo & GetCellToOutputWeights() const
Definition: LstmParams.hpp:161
const TensorInfo * m_InputToCellWeights
Definition: LstmParams.hpp:91
const TensorInfo & GetInputToCellWeights() const
Definition: LstmParams.hpp:129
const TensorInfo * m_RecurrentToInputWeights
Definition: LstmParams.hpp:93
const TensorInfo * m_RecurrentToOutputWeights
Definition: LstmParams.hpp:96
const ConstTensor * m_InputLayerNormWeights
Definition: LstmParams.hpp:57
const TensorInfo * m_CellToOutputWeights
Definition: LstmParams.hpp:99
const ConstTensor * m_RecurrentToOutputWeights
Definition: LstmParams.hpp:47
const TensorInfo & GetInputToOutputWeights() const
Definition: LstmParams.hpp:133
const ConstTensor * m_ProjectionBias
Definition: LstmParams.hpp:56
const TensorInfo * m_CellLayerNormWeights
Definition: LstmParams.hpp:108
const TensorInfo * m_InputToForgetWeights
Definition: LstmParams.hpp:90
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:199
const TensorInfo * m_ProjectionBias
Definition: LstmParams.hpp:105
const TensorInfo * m_CellToInputWeights
Definition: LstmParams.hpp:97
const TensorInfo & GetRecurrentToForgetWeights() const
Definition: LstmParams.hpp:141
const TensorInfo * m_InputGateBias
Definition: LstmParams.hpp:100
const ConstTensor * m_CellLayerNormWeights
Definition: LstmParams.hpp:59
const ConstTensor * m_ForgetGateBias
Definition: LstmParams.hpp:52
const ConstTensor * m_InputToCellWeights
Definition: LstmParams.hpp:42
const TensorInfo & GetInputToInputWeights() const
Definition: LstmParams.hpp:121
const TensorInfo & GetOutputLayerNormWeights() const
Definition: LstmParams.hpp:201
const ConstTensor * m_InputToOutputWeights
Definition: LstmParams.hpp:43
const TensorInfo * m_InputToOutputWeights
Definition: LstmParams.hpp:92
const TensorInfo * m_InputToInputWeights
Definition: LstmParams.hpp:89
const ConstTensor * m_RecurrentToForgetWeights
Definition: LstmParams.hpp:45
const TensorInfo & GetForgetGateBias() const
Definition: LstmParams.hpp:169
const ConstTensor * m_RecurrentToInputWeights
Definition: LstmParams.hpp:44
const TensorInfo * m_CellToForgetWeights
Definition: LstmParams.hpp:98
const TensorInfo * m_RecurrentToForgetWeights
Definition: LstmParams.hpp:94
const TensorInfo & GetInputGateBias() const
Definition: LstmParams.hpp:165
const TensorInfo & GetProjectionWeights() const
Definition: LstmParams.hpp:181
const TensorInfo & GetInputToForgetWeights() const
Definition: LstmParams.hpp:125
const TensorInfo & GetInputLayerNormWeights() const
Definition: LstmParams.hpp:189
const ConstTensor * m_OutputLayerNormWeights
Definition: LstmParams.hpp:60
const TensorInfo & GetOutputGateBias() const
Definition: LstmParams.hpp:177
const TensorInfo * m_CellBias
Definition: LstmParams.hpp:102
const TensorInfo & GetProjectionBias() const
Definition: LstmParams.hpp:185
const ConstTensor * m_InputToForgetWeights
Definition: LstmParams.hpp:41
const ConstTensor * m_InputToInputWeights
Definition: LstmParams.hpp:40