1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
|
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "TensorFwd.hpp"
namespace armnn
{
/// Plain bundle of pointers to constant tensors, one pointer per named LSTM parameter.
///
/// A default-constructed instance has every pointer set to nullptr; callers fill in
/// the entries they need afterwards. The struct declares no destructor and never
/// frees what the pointers refer to, so the pointed-to tensors must be kept alive
/// by whoever set them.
/// NOTE(review): which entries may legitimately stay null is decided by the code
/// consuming this struct, not here - confirm against the consumer.
struct LstmInputParams
{
    /// Creates a parameter set in which every tensor pointer is null
    /// (see the default member initialisers below).
    LstmInputParams() {}

    // Weights named "input to <gate>".
    const ConstTensor* m_InputToInputWeights      = nullptr;
    const ConstTensor* m_InputToForgetWeights     = nullptr;
    const ConstTensor* m_InputToCellWeights       = nullptr;
    const ConstTensor* m_InputToOutputWeights     = nullptr;

    // Weights named "recurrent to <gate>".
    const ConstTensor* m_RecurrentToInputWeights  = nullptr;
    const ConstTensor* m_RecurrentToForgetWeights = nullptr;
    const ConstTensor* m_RecurrentToCellWeights   = nullptr;
    const ConstTensor* m_RecurrentToOutputWeights = nullptr;

    // Weights named "cell to <gate>".
    const ConstTensor* m_CellToInputWeights       = nullptr;
    const ConstTensor* m_CellToForgetWeights      = nullptr;
    const ConstTensor* m_CellToOutputWeights      = nullptr;

    // Per-gate bias tensors.
    const ConstTensor* m_InputGateBias            = nullptr;
    const ConstTensor* m_ForgetGateBias           = nullptr;
    const ConstTensor* m_CellBias                 = nullptr;
    const ConstTensor* m_OutputGateBias           = nullptr;

    // Projection weights and bias.
    const ConstTensor* m_ProjectionWeights        = nullptr;
    const ConstTensor* m_ProjectionBias           = nullptr;

    // Per-gate layer-normalisation weights.
    const ConstTensor* m_InputLayerNormWeights    = nullptr;
    const ConstTensor* m_ForgetLayerNormWeights   = nullptr;
    const ConstTensor* m_CellLayerNormWeights     = nullptr;
    const ConstTensor* m_OutputLayerNormWeights   = nullptr;
};
} // namespace armnn
|