From a65b7aeafc0ef6acf40e4a8a6d36206bf53d717c Mon Sep 17 00:00:00 2001 From: Matteo Martincigh Date: Wed, 14 Nov 2018 12:39:55 +0000 Subject: IVGCVSW-2092 Port LSTMCell::Eval to ArmNN * Ported Google's LSTM implementation to RefLstmFloat32Workload * Fixed the code throughout because of an error in the docs around the scratch buffer size * Updated IsLstmSupported * Added the unit tests !android-nn-driver:127 Change-Id: I5577b7e39ca52df1a7f102a9b437df6aa99520b6 --- .../reference/workloads/RefLstmFloat32Workload.hpp | 24 +++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) (limited to 'src/backends/reference/workloads/RefLstmFloat32Workload.hpp') diff --git a/src/backends/reference/workloads/RefLstmFloat32Workload.hpp b/src/backends/reference/workloads/RefLstmFloat32Workload.hpp index 1f634d3ca1..a2dead8b9c 100644 --- a/src/backends/reference/workloads/RefLstmFloat32Workload.hpp +++ b/src/backends/reference/workloads/RefLstmFloat32Workload.hpp @@ -5,6 +5,8 @@ #pragma once +#include + #include #include @@ -14,8 +16,28 @@ namespace armnn class RefLstmFloat32Workload : public Float32Workload { public: - using Float32Workload::Float32Workload; + explicit RefLstmFloat32Workload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info); + virtual void Execute() const override; + +private: + std::unique_ptr m_InputToInputWeightsTensor; + std::unique_ptr m_InputToForgetWeightsTensor; + std::unique_ptr m_InputToCellWeightsTensor; + std::unique_ptr m_InputToOutputWeightsTensor; + std::unique_ptr m_RecurrentToInputWeightsTensor; + std::unique_ptr m_RecurrentToForgetWeightsTensor; + std::unique_ptr m_RecurrentToCellWeightsTensor; + std::unique_ptr m_RecurrentToOutputWeightsTensor; + std::unique_ptr m_CellToInputWeightsTensor; + std::unique_ptr m_CellToForgetWeightsTensor; + std::unique_ptr m_CellToOutputWeightsTensor; + std::unique_ptr m_InputGateBiasTensor; + std::unique_ptr m_ForgetGateBiasTensor; + std::unique_ptr m_CellBiasTensor; + std::unique_ptr m_OutputGateBiasTensor; + std::unique_ptr m_ProjectionWeightsTensor; + std::unique_ptr m_ProjectionBiasTensor; }; } //namespace armnn -- cgit v1.2.1