diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/armnn/Descriptors.hpp | 60 | ||||
-rw-r--r-- | include/armnn/DescriptorsFwd.hpp | 1 | ||||
-rw-r--r-- | include/armnn/ILayerSupport.hpp | 10 | ||||
-rw-r--r-- | include/armnn/ILayerVisitor.hpp | 10 | ||||
-rw-r--r-- | include/armnn/INetwork.hpp | 9 | ||||
-rw-r--r-- | include/armnn/LayerVisitorBase.hpp | 5 |
6 files changed, 95 insertions, 0 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 57917261d4..95eeaaa420 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1083,6 +1083,66 @@ struct PreCompiledDescriptor unsigned int m_NumOutputSlots; }; +/// A QLstmDescriptor for the QLstmLayer. +struct QLstmDescriptor +{ + QLstmDescriptor() + : m_CellClip(0.0) + , m_ProjectionClip(0.0) + , m_CifgEnabled(true) + , m_PeepholeEnabled(false) + , m_ProjectionEnabled(false) + , m_LayerNormEnabled(false) + , m_InputIntermediateScale(0.0) + , m_ForgetIntermediateScale(0.0) + , m_CellIntermediateScale(0.0) + , m_OutputIntermediateScale(0.0) + , m_HiddenStateZeroPoint(0) + , m_HiddenStateScale(0.0) + {} + + bool operator ==(const QLstmDescriptor& rhs) const + { + return m_CellClip == rhs.m_CellClip && + m_ProjectionClip == rhs.m_ProjectionClip && + m_CifgEnabled == rhs.m_CifgEnabled && + m_PeepholeEnabled == rhs.m_PeepholeEnabled && + m_ProjectionEnabled == rhs.m_ProjectionEnabled && + m_LayerNormEnabled == rhs.m_LayerNormEnabled && + m_InputIntermediateScale == rhs.m_InputIntermediateScale && + m_ForgetIntermediateScale == rhs.m_ForgetIntermediateScale && + m_CellIntermediateScale == rhs.m_CellIntermediateScale && + m_OutputIntermediateScale == rhs.m_OutputIntermediateScale && + m_HiddenStateZeroPoint == rhs.m_HiddenStateZeroPoint && + m_HiddenStateScale == rhs.m_HiddenStateScale; + } + + /// Clipping threshold value for the cell state + float m_CellClip; + /// Clipping threshold value for the projection + float m_ProjectionClip; + /// Enable/disable CIFG (coupled input & forget gate). + bool m_CifgEnabled; + /// Enable/disable peephole + bool m_PeepholeEnabled; + /// Enable/disable the projection layer + bool m_ProjectionEnabled; + /// Enable/disable layer normalization + bool m_LayerNormEnabled; + /// Input intermediate quantization scale + float m_InputIntermediateScale; + /// Forget intermediate quantization scale + float m_ForgetIntermediateScale; + /// Cell intermediate quantization scale + float m_CellIntermediateScale; + /// Output intermediate quantization scale + float m_OutputIntermediateScale; + /// Hidden State zero point + int32_t m_HiddenStateZeroPoint; + /// Hidden State quantization scale + float m_HiddenStateScale; +}; + /// A TransposeConvolution2dDescriptor for the TransposeConvolution2dLayer. struct TransposeConvolution2dDescriptor { diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index 1298c1ce01..f0903728dd 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -29,6 +29,7 @@ struct PadDescriptor; struct PermuteDescriptor; struct Pooling2dDescriptor; struct PreCompiledDescriptor; +struct QLstmDescriptor; struct ReshapeDescriptor; struct ResizeBilinearDescriptor; struct ResizeDescriptor; diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index 8274b05535..58509c906c 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -284,6 +284,16 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsQLstmSupported(const TensorInfo& input, + const TensorInfo& previousOutputIn, + const TensorInfo& previousCellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsQuantizedLstmSupported(const TensorInfo& input, const TensorInfo& previousCellStateIn, const TensorInfo& previousOutputIn, diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp index 972915dc0f..530e74f30a 100644 --- a/include/armnn/ILayerVisitor.hpp +++ b/include/armnn/ILayerVisitor.hpp @@ -360,6 +360,16 @@ public: virtual void VisitQuantizeLayer(const IConnectableLayer* layer, const char* name = nullptr) = 0; + /// Function a QLstm layer should call back to when its Accept(ILayerVisitor&) function is invoked. + /// @param layer - pointer to the layer which is calling back to this visit function. + /// @param descriptor - Parameters controlling the operation of the QLstm operation. + /// @param params - The weights and biases for the layer + /// @param name - Optional name for the layer. + virtual void VisitQLstmLayer(const IConnectableLayer* layer, + const QLstmDescriptor& descriptor, + const LstmInputParams& params, + const char* name = nullptr) = 0; + /// Function a QuantizedLstm layer should call back to when its Accept(ILayerVisitor&) function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. /// @param params - The weights and biases for the Quantized LSTM cell diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index c976b82c0b..84ecaebfb9 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -555,6 +555,15 @@ public: virtual IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params, const char* name = nullptr) = 0; + /// Add a QLstm layer to the network + /// @param descriptor - Parameters for the QLstm operation + /// @param params - Weights and biases for the layer + /// @param name - Optional name for the layer + /// @return - Interface for configuring the layer. + virtual IConnectableLayer* AddQLstmLayer(const QLstmDescriptor& descriptor, + const LstmInputParams& params, + const char* name = nullptr) = 0; + virtual void Accept(ILayerVisitor& visitor) const = 0; protected: diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp index 9335ff8b0b..95d6bd37bd 100644 --- a/include/armnn/LayerVisitorBase.hpp +++ b/include/armnn/LayerVisitorBase.hpp @@ -183,6 +183,11 @@ public: void VisitQuantizeLayer(const IConnectableLayer*, const char*) override { DefaultPolicy::Apply(__func__); } + void VisitQLstmLayer(const IConnectableLayer*, + const QLstmDescriptor&, + const LstmInputParams&, + const char*) override { DefaultPolicy::Apply(__func__); } + void VisitQuantizedLstmLayer(const IConnectableLayer*, const QuantizedLstmInputParams&, const char*) override { DefaultPolicy::Apply(__func__); } |