aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/LstmParams.hpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /include/armnn/LstmParams.hpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'include/armnn/LstmParams.hpp')
-rw-r--r--include/armnn/LstmParams.hpp55
1 files changed, 55 insertions, 0 deletions
diff --git a/include/armnn/LstmParams.hpp b/include/armnn/LstmParams.hpp
new file mode 100644
index 0000000000..cfca0df5bb
--- /dev/null
+++ b/include/armnn/LstmParams.hpp
@@ -0,0 +1,55 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#pragma once
+
+#include "TensorFwd.hpp"
+
+namespace armnn
+{
+
+struct LstmInputParams
+{
+ LstmInputParams()
+ : m_InputToInputWeights(nullptr)
+ , m_InputToForgetWeights(nullptr)
+ , m_InputToCellWeights(nullptr)
+ , m_InputToOutputWeights(nullptr)
+ , m_RecurrentToInputWeights(nullptr)
+ , m_RecurrentToForgetWeights(nullptr)
+ , m_RecurrentToCellWeights(nullptr)
+ , m_RecurrentToOutputWeights(nullptr)
+ , m_CellToInputWeights(nullptr)
+ , m_CellToForgetWeights(nullptr)
+ , m_CellToOutputWeights(nullptr)
+ , m_InputGateBias(nullptr)
+ , m_ForgetGateBias(nullptr)
+ , m_CellBias(nullptr)
+ , m_OutputGateBias(nullptr)
+ , m_ProjectionWeights(nullptr)
+ , m_ProjectionBias(nullptr)
+ {
+ }
+
+ const ConstTensor* m_InputToInputWeights;
+ const ConstTensor* m_InputToForgetWeights;
+ const ConstTensor* m_InputToCellWeights;
+ const ConstTensor* m_InputToOutputWeights;
+ const ConstTensor* m_RecurrentToInputWeights;
+ const ConstTensor* m_RecurrentToForgetWeights;
+ const ConstTensor* m_RecurrentToCellWeights;
+ const ConstTensor* m_RecurrentToOutputWeights;
+ const ConstTensor* m_CellToInputWeights;
+ const ConstTensor* m_CellToForgetWeights;
+ const ConstTensor* m_CellToOutputWeights;
+ const ConstTensor* m_InputGateBias;
+ const ConstTensor* m_ForgetGateBias;
+ const ConstTensor* m_CellBias;
+ const ConstTensor* m_OutputGateBias;
+ const ConstTensor* m_ProjectionWeights;
+ const ConstTensor* m_ProjectionBias;
+};
+
+} // namespace armnn
+