ArmNN 21.02
ClLstmFloatWorkload.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Descriptors.hpp>
9 #include <armnn/LstmParams.hpp>
12 
13 #include <arm_compute/runtime/CL/functions/CLLSTMLayer.h>
14 
15 namespace armnn
16 {
17 
18 class ClLstmFloatWorkload : public FloatWorkload<LstmQueueDescriptor>
19 {
20 public:
21  ClLstmFloatWorkload(const LstmQueueDescriptor& descriptor,
22  const WorkloadInfo& info,
23  const arm_compute::CLCompileContext& clCompileContext);
24  void Execute() const override;
25 
26 private:
27  mutable arm_compute::CLLSTMLayer m_LstmLayer;
28 
29  std::unique_ptr<arm_compute::CLTensor> m_InputToInputWeightsTensor;
30  std::unique_ptr<arm_compute::CLTensor> m_InputToForgetWeightsTensor;
31  std::unique_ptr<arm_compute::CLTensor> m_InputToCellWeightsTensor;
32  std::unique_ptr<arm_compute::CLTensor> m_InputToOutputWeightsTensor;
33  std::unique_ptr<arm_compute::CLTensor> m_RecurrentToInputWeightsTensor;
34  std::unique_ptr<arm_compute::CLTensor> m_RecurrentToForgetWeightsTensor;
35  std::unique_ptr<arm_compute::CLTensor> m_RecurrentToCellWeightsTensor;
36  std::unique_ptr<arm_compute::CLTensor> m_RecurrentToOutputWeightsTensor;
37  std::unique_ptr<arm_compute::CLTensor> m_CellToInputWeightsTensor;
38  std::unique_ptr<arm_compute::CLTensor> m_CellToForgetWeightsTensor;
39  std::unique_ptr<arm_compute::CLTensor> m_CellToOutputWeightsTensor;
40  std::unique_ptr<arm_compute::CLTensor> m_InputGateBiasTensor;
41  std::unique_ptr<arm_compute::CLTensor> m_ForgetGateBiasTensor;
42  std::unique_ptr<arm_compute::CLTensor> m_CellBiasTensor;
43  std::unique_ptr<arm_compute::CLTensor> m_OutputGateBiasTensor;
44  std::unique_ptr<arm_compute::CLTensor> m_ProjectionWeightsTensor;
45  std::unique_ptr<arm_compute::CLTensor> m_ProjectionBiasTensor;
46  std::unique_ptr<arm_compute::CLTensor> m_InputLayerNormWeightsTensor;
47  std::unique_ptr<arm_compute::CLTensor> m_ForgetLayerNormWeightsTensor;
48  std::unique_ptr<arm_compute::CLTensor> m_CellLayerNormWeightsTensor;
49  std::unique_ptr<arm_compute::CLTensor> m_OutputLayerNormWeightsTensor;
50 
51  std::unique_ptr<arm_compute::CLTensor> m_ScratchBuffer;
52 
53  void FreeUnusedTensors();
54 };
55 
56 arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn,
57  const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
58  const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
59  const TensorInfo& output, const LstmDescriptor &descriptor,
60  const LstmInputParamsInfo& paramsInfo);
61 } //namespace armnn
ClLstmFloatWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info, const arm_compute::CLCompileContext &clCompileContext)
arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo &input, const TensorInfo &outputStateIn, const TensorInfo &cellStateIn, const TensorInfo &scratchBuffer, const TensorInfo &outputStateOut, const TensorInfo &cellStateOut, const TensorInfo &output, const LstmDescriptor &descriptor, const LstmInputParamsInfo &paramsInfo)
Copyright (c) 2021 ARM Limited and Contributors.
An LstmDescriptor for the LstmLayer.
Status: enumeration. Definition: Types.hpp:26
Contains information about inputs and outputs to a layer.