ArmNN
 22.11
Lstm.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/TypesUtils.hpp>
10 
11 #include "Encoders.hpp"
12 #include "Decoders.hpp"
13 
14 namespace armnn
15 {
16 
17 void LstmImpl(const LstmDescriptor& descriptor,
18  const TensorInfo& inputInfo,
19  const TensorInfo& outputInfo,
20  const TensorShape& inputToOutputWeightsShape,
21  const TensorShape& recurrentToOutputWeightsShape,
22  std::unique_ptr<Decoder<float>>& inputData,
23  std::unique_ptr<Decoder<float>>& outputStateIn,
24  std::unique_ptr<Decoder<float>>& cellStateIn,
25  std::unique_ptr<Encoder<float>>& outputStateOut,
26  std::unique_ptr<Encoder<float>>& cellStateOut,
27  std::unique_ptr<Encoder<float>>& output,
28  std::unique_ptr<Decoder<float>>& cellStateOutDecoder,
29  std::unique_ptr<Decoder<float>>& outputDecoder,
30  std::unique_ptr<Decoder<float>>& inputToInputWeightsTensor,
31  std::unique_ptr<Decoder<float>>& inputToForgetWeightsTensor,
32  std::unique_ptr<Decoder<float>>& inputToCellWeightsTensor,
33  std::unique_ptr<Decoder<float>>& inputToOutputWeightsTensor,
34  std::unique_ptr<Decoder<float>>& recurrentToInputWeightsTensor,
35  std::unique_ptr<Decoder<float>>& recurrentToForgetWeightsTensor,
36  std::unique_ptr<Decoder<float>>& recurrentToCellWeightsTensor,
37  std::unique_ptr<Decoder<float>>& recurrentToOutputWeightsTensor,
38  std::unique_ptr<Decoder<float>>& cellToInputWeightsTensor,
39  std::unique_ptr<Decoder<float>>& cellToForgetWeightsTensor,
40  std::unique_ptr<Decoder<float>>& cellToOutputWeightsTensor,
41  std::unique_ptr<Decoder<float>>& inputGateBiasTensor,
42  std::unique_ptr<Decoder<float>>& forgetGateBiasTensor,
43  std::unique_ptr<Decoder<float>>& cellBiasTensor,
44  std::unique_ptr<Decoder<float>>& outputGateBiasTensor,
45  std::unique_ptr<Decoder<float>>& projectionWeightsTensor,
46  std::unique_ptr<Decoder<float>>& projectionBiasTensor,
47  std::unique_ptr<Decoder<float>>& inputLayerNormWeights,
48  std::unique_ptr<Decoder<float>>& forgetLayerNormWeights,
49  std::unique_ptr<Decoder<float>>& cellLayerNormWeights,
50  std::unique_ptr<Decoder<float>>& outputLayerNormWeights,
51  std::unique_ptr<Encoder<float>>& inputGateScratch,
52  std::unique_ptr<Encoder<float>>& cellScratch,
53  std::unique_ptr<Encoder<float>>& forgetGateScratch,
54  std::unique_ptr<Encoder<float>>& outputGateScratch,
55  std::unique_ptr<Decoder<float>>& inputGateScratchDecoder,
56  std::unique_ptr<Decoder<float>>& cellScratchDecoder,
57  std::unique_ptr<Decoder<float>>& forgetGateScratchDecoder,
58  std::unique_ptr<Decoder<float>>& outputGateScratchDecoder,
59  float layerNormEpsilon);
60 
61 } //namespace armnn
void LstmImpl(const LstmDescriptor &descriptor, const TensorInfo &inputInfo, const TensorInfo &outputInfo, const TensorShape &inputToOutputWeightsShape, const TensorShape &recurrentToOutputWeightsShape, std::unique_ptr< Decoder< float >> &inputData, std::unique_ptr< Decoder< float >> &outputStateIn, std::unique_ptr< Decoder< float >> &cellStateIn, std::unique_ptr< Encoder< float >> &outputStateOut, std::unique_ptr< Encoder< float >> &cellStateOut, std::unique_ptr< Encoder< float >> &output, std::unique_ptr< Decoder< float >> &cellStateOutDecoder, std::unique_ptr< Decoder< float >> &outputDecoder, std::unique_ptr< Decoder< float >> &inputToInputWeightsTensor, std::unique_ptr< Decoder< float >> &inputToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &inputToCellWeightsTensor, std::unique_ptr< Decoder< float >> &inputToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToInputWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToCellWeightsTensor, std::unique_ptr< Decoder< float >> &recurrentToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToInputWeightsTensor, std::unique_ptr< Decoder< float >> &cellToForgetWeightsTensor, std::unique_ptr< Decoder< float >> &cellToOutputWeightsTensor, std::unique_ptr< Decoder< float >> &inputGateBiasTensor, std::unique_ptr< Decoder< float >> &forgetGateBiasTensor, std::unique_ptr< Decoder< float >> &cellBiasTensor, std::unique_ptr< Decoder< float >> &outputGateBiasTensor, std::unique_ptr< Decoder< float >> &projectionWeightsTensor, std::unique_ptr< Decoder< float >> &projectionBiasTensor, std::unique_ptr< Decoder< float >> &inputLayerNormWeights, std::unique_ptr< Decoder< float >> &forgetLayerNormWeights, std::unique_ptr< Decoder< float >> &cellLayerNormWeights, std::unique_ptr< Decoder< float >> &outputLayerNormWeights, std::unique_ptr< Encoder< float >> &inputGateScratch, std::unique_ptr< Encoder< float >> &cellScratch, std::unique_ptr< Encoder< float >> &forgetGateScratch, std::unique_ptr< Encoder< float >> &outputGateScratch, std::unique_ptr< Decoder< float >> &inputGateScratchDecoder, std::unique_ptr< Decoder< float >> &cellScratchDecoder, std::unique_ptr< Decoder< float >> &forgetGateScratchDecoder, std::unique_ptr< Decoder< float >> &outputGateScratchDecoder, float layerNormEpsilon)
Definition: Lstm.cpp:13
Copyright (c) 2021 ARM Limited and Contributors.