ArmNN 20.11 - ClLstmFloatWorkload.hpp
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Descriptors.hpp>
#include <armnn/LstmParams.hpp>
#include <backendsCommon/Workload.hpp>
#include <backendsCommon/WorkloadData.hpp>

#include <arm_compute/runtime/CL/functions/CLLSTMLayer.h>

namespace armnn
{

// Float32 LSTM workload for the CL (GPU/OpenCL) backend, backed by Arm Compute Library's CLLSTMLayer.
class ClLstmFloatWorkload : public FloatWorkload<LstmQueueDescriptor>
{
public:
    ClLstmFloatWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
    void Execute() const override;

private:
    // Mutable because CLLSTMLayer::run() is non-const while Execute() is const.
    mutable arm_compute::CLLSTMLayer m_LstmLayer;

    // CL tensors for the LSTM weights and biases; optional parameters
    // (input gate, peephole, projection, layer normalization) may be left unused.
    std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_CellToInputWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_CellToForgetWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_CellToOutputWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor;
    std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor;
    std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor;
    std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
    std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor;
    std::unique_ptr<arm_compute::CLTensor> m_InputLayerNormWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_ForgetLayerNormWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_CellLayerNormWeightsTensor;
    std::unique_ptr<arm_compute::CLTensor> m_OutputLayerNormWeightsTensor;

    // Scratch buffer used by CLLSTMLayer during computation.
    std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer;

    // Releases weight/bias tensors that are no longer needed once the layer has been configured.
    void FreeUnusedTensors();
};

// Checks whether an LSTM with the given tensor infos and descriptor configuration
// is supported by the CL backend, without allocating any GPU resources.
arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn,
                                                const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                                                const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                                                const TensorInfo& output, const LstmDescriptor& descriptor,
                                                const LstmInputParamsInfo& paramsInfo);

} //namespace armnn
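
ClLstmFloatWorkloadValidate lets a caller ask, before any GPU resources are allocated, whether a particular LSTM configuration is supported by the CL backend. The sketch below is illustrative rather than taken from ArmNN itself: the IsLstmSupportedOnCl helper, the sizes, and the reuse of one weight/bias TensorInfo per shape are assumptions made for brevity, and it models a CIFG LSTM (no input gate) with peephole, projection and layer normalization disabled.

#include "ClLstmFloatWorkload.hpp"   // the header documented above
#include <armnn/Tensor.hpp>

bool IsLstmSupportedOnCl()   // hypothetical helper, not part of ArmNN
{
    using namespace armnn;

    const unsigned int numBatches = 2;
    const unsigned int inputSize  = 5;
    const unsigned int numUnits   = 20;
    const unsigned int outputSize = 20;   // no projection, so outputSize == numUnits

    const TensorInfo input         ({numBatches, inputSize},    DataType::Float32);
    const TensorInfo outputStateIn ({numBatches, outputSize},   DataType::Float32);
    const TensorInfo cellStateIn   ({numBatches, numUnits},     DataType::Float32);
    const TensorInfo scratchBuffer ({numBatches, numUnits * 3}, DataType::Float32); // 3*numUnits with CIFG, 4*numUnits otherwise
    const TensorInfo outputStateOut({numBatches, outputSize},   DataType::Float32);
    const TensorInfo cellStateOut  ({numBatches, numUnits},     DataType::Float32);
    const TensorInfo output        ({numBatches, outputSize},   DataType::Float32);

    LstmDescriptor descriptor;
    descriptor.m_ActivationFunc    = 4;     // tanh in the LSTM activation encoding
    descriptor.m_ClippingThresCell = 0.0f;
    descriptor.m_ClippingThresProj = 0.0f;
    descriptor.m_CifgEnabled       = true;  // no input gate, so the *ToInput* and InputGateBias params stay unset
    descriptor.m_PeepholeEnabled   = false;
    descriptor.m_ProjectionEnabled = false;
    descriptor.m_LayerNormEnabled  = false;

    // Weight/bias shapes follow the usual LSTM conventions; one TensorInfo per shape is reused here.
    const TensorInfo inputWeights    ({numUnits, inputSize},  DataType::Float32);
    const TensorInfo recurrentWeights({numUnits, outputSize}, DataType::Float32);
    const TensorInfo bias            ({numUnits},             DataType::Float32);

    LstmInputParamsInfo paramsInfo;
    paramsInfo.m_InputToForgetWeights     = &inputWeights;
    paramsInfo.m_InputToCellWeights       = &inputWeights;
    paramsInfo.m_InputToOutputWeights     = &inputWeights;
    paramsInfo.m_RecurrentToForgetWeights = &recurrentWeights;
    paramsInfo.m_RecurrentToCellWeights   = &recurrentWeights;
    paramsInfo.m_RecurrentToOutputWeights = &recurrentWeights;
    paramsInfo.m_ForgetGateBias           = &bias;
    paramsInfo.m_CellBias                 = &bias;
    paramsInfo.m_OutputGateBias           = &bias;

    const arm_compute::Status status = ClLstmFloatWorkloadValidate(
        input, outputStateIn, cellStateIn, scratchBuffer,
        outputStateOut, cellStateOut, output, descriptor, paramsInfo);

    return static_cast<bool>(status); // true when CLLSTMLayer accepts this configuration
}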