ArmNN  NotReleased
QuantizedLstmLayer.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <Layer.hpp>
8 
9 namespace armnn
10 {
11 
12 class ScopedCpuTensorHandle;
13 
15 { // NOTE(review): the declaration line (listing line 14) is absent from this extract; presumably this is the body of QuantizedLstmParameters — confirm against the real header.
17  std::unique_ptr<ScopedCpuTensorHandle> m_InputToInputWeights; ///< 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
19  std::unique_ptr<ScopedCpuTensorHandle> m_InputToForgetWeights; ///< 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
21  std::unique_ptr<ScopedCpuTensorHandle> m_InputToCellWeights; ///< 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
23  std::unique_ptr<ScopedCpuTensorHandle> m_InputToOutputWeights; ///< 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
24 
26  std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToInputWeights; ///< 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
28  std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToForgetWeights; ///< 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
30  std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToCellWeights; ///< 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
32  std::unique_ptr<ScopedCpuTensorHandle> m_RecurrentToOutputWeights; ///< 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
33 
35  std::unique_ptr<ScopedCpuTensorHandle> m_InputGateBias; ///< 1D bias tensor with dimensions [outputSize] (int32).
37  std::unique_ptr<ScopedCpuTensorHandle> m_ForgetGateBias; ///< 1D bias tensor with dimensions [outputSize] (int32).
39  std::unique_ptr<ScopedCpuTensorHandle> m_CellBias; ///< 1D bias tensor with dimensions [outputSize] (int32).
41  std::unique_ptr<ScopedCpuTensorHandle> m_OutputGateBias; ///< 1D bias tensor with dimensions [outputSize] (int32).
42 };
43 
/// This layer represents a QuantizedLstm operation.
/// It derives from Layer; the constructor and destructor are protected, so code
/// outside the class hierarchy cannot construct or destroy instances directly.
45 class QuantizedLstmLayer : public Layer
46 {
47 public:
48 
    // NOTE(review): listing line 49 is absent from this extract. The symbol index lists a member
    // `QuantizedLstmParameters m_QuantizedLstmParameters`, which is presumably declared here — confirm
    // against the real header.
50 
    /// Returns an owning pointer to an IWorkload obtained using the supplied factory.
    /// Presumably this is the workload that executes this layer; the implementation is not visible here.
    /// @param factory Workload factory passed in by the caller.
55  virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory& factory) const override;
56 
    /// Returns a pointer to a QuantizedLstmLayer associated with the given graph.
    /// Presumably a copy of this layer added to @p graph — ownership is not visible here; verify.
59  QuantizedLstmLayer* Clone(Graph& graph) const override;
60 
    /// Overrides the base-class shape-validation hook; takes no arguments and returns nothing.
    /// How failures are reported is not visible in this header.
63  void ValidateTensorShapesFromInputs() override;
64 
    /// Maps a list of input tensor shapes to a list of output tensor shapes.
    /// @param inputShapes Shapes of the layer inputs (expected count is not visible here).
    /// @return Inferred output shapes.
69  std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
70 
    /// Visitor entry point: takes an ILayerVisitor and does not modify this layer (const).
71  void Accept(ILayerVisitor& visitor) const override;
72 
73 protected:
    /// Constructs the layer with the given name. Protected, so not callable from outside the hierarchy.
    /// @param name Layer name (nullability is not visible here).
76  QuantizedLstmLayer(const char* name);
77 
    /// Defaulted destructor; protected.
79  ~QuantizedLstmLayer() = default;
80 
    /// Returns Layer::ConstantTensors, which the index below defines as a vector of
    /// reference wrappers around std::unique_ptr<ScopedCpuTensorHandle>.
    /// Presumably these refer to this layer's weight and bias handles — verify in the implementation.
83  Layer::ConstantTensors GetConstantTensorsByRef() override;
84 };
85 
86 } // namespace armnn
std::unique_ptr< ScopedCpuTensorHandle > m_ForgetGateBias
A unique pointer to a 1D bias tensor with dimensions [outputSize] (int32).
std::unique_ptr< ScopedCpuTensorHandle > m_InputGateBias
A unique pointer to a 1D bias tensor with dimensions [outputSize] (int32).
std::unique_ptr< ScopedCpuTensorHandle > m_RecurrentToCellWeights
A unique pointer to a 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
std::unique_ptr< ScopedCpuTensorHandle > m_RecurrentToOutputWeights
A unique pointer to a 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
std::unique_ptr< ScopedCpuTensorHandle > m_RecurrentToInputWeights
A unique pointer to a 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
std::vector< std::reference_wrapper< std::unique_ptr< ScopedCpuTensorHandle > >> ConstantTensors
Definition: Layer.hpp:356
armnnUtils::Sockets::Socket Accept(Socket s, sockaddr *addr, socklen_t *addrlen, int flags)
std::unique_ptr< ScopedCpuTensorHandle > m_InputToInputWeights
A unique pointer to a 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
std::unique_ptr< ScopedCpuTensorHandle > m_OutputGateBias
A unique pointer to a 1D bias tensor with dimensions [outputSize] (int32).
std::unique_ptr< ScopedCpuTensorHandle > m_RecurrentToForgetWeights
A unique pointer to a 2D weights tensor with dimensions [outputSize, outputSize] (QAsymm8).
std::unique_ptr< ScopedCpuTensorHandle > m_InputToCellWeights
A unique pointer to a 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
QuantizedLstmParameters m_QuantizedLstmParameters
std::unique_ptr< ScopedCpuTensorHandle > m_InputToOutputWeights
A unique pointer to a 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
std::unique_ptr< armnn::IWorkload > CreateWorkload(const armnn::IWorkloadFactory &workloadFactory, const armnn::WorkloadInfo &info, const DescriptorType &descriptor)
std::unique_ptr< ScopedCpuTensorHandle > m_InputToForgetWeights
A unique pointer to a 2D weights tensor with dimensions [outputSize, inputSize] (QAsymm8).
std::unique_ptr< ScopedCpuTensorHandle > m_CellBias
A unique pointer to a 1D bias tensor with dimensions [outputSize] (int32).
This layer represents a QuantizedLstm operation.