ArmNN
 21.02
QuantizedLstmParams.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "TensorFwd.hpp"
8 #include "Exceptions.hpp"
9 
10 namespace armnn
11 {
12 
14 {
16  : m_InputToInputWeights(nullptr)
17  , m_InputToForgetWeights(nullptr)
18  , m_InputToCellWeights(nullptr)
19  , m_InputToOutputWeights(nullptr)
20 
21  , m_RecurrentToInputWeights(nullptr)
23  , m_RecurrentToCellWeights(nullptr)
25 
26  , m_InputGateBias(nullptr)
27  , m_ForgetGateBias(nullptr)
28  , m_CellBias(nullptr)
29  , m_OutputGateBias(nullptr)
30  {
31  }
32 
37 
42 
47 
48  const ConstTensor& Deref(const ConstTensor* tensorPtr) const
49  {
50  if (tensorPtr != nullptr)
51  {
52  const ConstTensor &temp = *tensorPtr;
53  return temp;
54  }
55  throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer");
56  }
57 
59  {
60  return Deref(m_InputToInputWeights);
61  }
62 
64  {
65  return Deref(m_InputToForgetWeights);
66  }
67 
69  {
70  return Deref(m_InputToCellWeights);
71  }
72 
74  {
75  return Deref(m_InputToOutputWeights);
76  }
77 
79  {
80  return Deref(m_RecurrentToInputWeights);
81  }
82 
84  {
85  return Deref(m_RecurrentToForgetWeights);
86  }
87 
89  {
90  return Deref(m_RecurrentToCellWeights);
91  }
92 
94  {
95  return Deref(m_RecurrentToOutputWeights);
96  }
97 
99  {
100  return Deref(m_InputGateBias);
101  }
102 
104  {
105  return Deref(m_ForgetGateBias);
106  }
107 
108  const ConstTensor& GetCellBias() const
109  {
110  return Deref(m_CellBias);
111  }
112 
114  {
115  return Deref(m_OutputGateBias);
116  }
117 };
118 
120 {
122  : m_InputToInputWeights(nullptr)
123  , m_InputToForgetWeights(nullptr)
124  , m_InputToCellWeights(nullptr)
125  , m_InputToOutputWeights(nullptr)
126 
127  , m_RecurrentToInputWeights(nullptr)
128  , m_RecurrentToForgetWeights(nullptr)
129  , m_RecurrentToCellWeights(nullptr)
130  , m_RecurrentToOutputWeights(nullptr)
131 
132  , m_InputGateBias(nullptr)
133  , m_ForgetGateBias(nullptr)
134  , m_CellBias(nullptr)
135  , m_OutputGateBias(nullptr)
136  {
137  }
138 
143 
148 
153 
154 
155  const TensorInfo& Deref(const TensorInfo* tensorInfo) const
156  {
157  if (tensorInfo != nullptr)
158  {
159  const TensorInfo &temp = *tensorInfo;
160  return temp;
161  }
162  throw InvalidArgumentException("Can't dereference a null pointer");
163  }
164 
166  {
167  return Deref(m_InputToInputWeights);
168  }
170  {
171  return Deref(m_InputToForgetWeights);
172  }
174  {
175  return Deref(m_InputToCellWeights);
176  }
178  {
179  return Deref(m_InputToOutputWeights);
180  }
181 
183  {
184  return Deref(m_RecurrentToInputWeights);
185  }
187  {
188  return Deref(m_RecurrentToForgetWeights);
189  }
191  {
192  return Deref(m_RecurrentToCellWeights);
193  }
195  {
196  return Deref(m_RecurrentToOutputWeights);
197  }
198 
200  {
201  return Deref(m_InputGateBias);
202  }
204  {
205  return Deref(m_ForgetGateBias);
206  }
207  const TensorInfo& GetCellBias() const
208  {
209  return Deref(m_CellBias);
210  }
212  {
213  return Deref(m_OutputGateBias);
214  }
215 };
216 
217 } // namespace armnn
218 
const ConstTensor & Deref(const ConstTensor *tensorPtr) const
const ConstTensor & GetRecurrentToOutputWeights() const
const ConstTensor * m_RecurrentToOutputWeights
const ConstTensor & GetRecurrentToForgetWeights() const
const TensorInfo & GetRecurrentToCellWeights() const
const ConstTensor * m_RecurrentToForgetWeights
const ConstTensor & GetCellBias() const
const TensorInfo & GetInputToOutputWeights() const
const TensorInfo & GetInputToInputWeights() const
Copyright (c) 2021 ARM Limited and Contributors.
const ConstTensor & GetInputToOutputWeights() const
const TensorInfo & Deref(const TensorInfo *tensorInfo) const
const ConstTensor & GetInputToCellWeights() const
const ConstTensor * m_InputToForgetWeights
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:314
const TensorInfo & GetRecurrentToForgetWeights() const
const ConstTensor & GetInputToInputWeights() const
const TensorInfo & GetInputToCellWeights() const
const TensorInfo & GetForgetGateBias() const
const ConstTensor & GetForgetGateBias() const
const ConstTensor * m_RecurrentToInputWeights
const TensorInfo & GetInputGateBias() const
const TensorInfo & GetInputToForgetWeights() const
const ConstTensor & GetInputGateBias() const
const TensorInfo & GetCellBias() const
const TensorInfo & GetOutputGateBias() const
const ConstTensor & GetRecurrentToCellWeights() const
const TensorInfo & GetRecurrentToInputWeights() const
const ConstTensor & GetInputToForgetWeights() const
const ConstTensor * m_RecurrentToCellWeights
const ConstTensor * m_InputToOutputWeights
const ConstTensor & GetRecurrentToInputWeights() const
const TensorInfo & GetRecurrentToOutputWeights() const
const ConstTensor & GetOutputGateBias() const