aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorJames Conroy <james.conroy@arm.com>2020-03-20 08:49:33 +0000
committerJames Conroy <james.conroy@arm.com>2020-03-20 14:53:44 +0000
commit586a9aac99312eb9cb304cbbd18cec46b9158e23 (patch)
tree6d620eae6dcfb920ac04eae43424548dc602a1eb /include
parentc94d3f7107b84b586791aa096f8641e6efa18c90 (diff)
downloadarmnn-586a9aac99312eb9cb304cbbd18cec46b9158e23.tar.gz
IVGCVSW-4549 Add front end for new QLSTM layer
* Added new layer QLstm (Android R HAL 1.3) * Made necessary updates to APIs * Added unit tests * This layer is functionally equivalent to the original unquantized LSTM layer with some additional quantization features added. Due to this, original LstmParams are used for this layer. Signed-off-by: James Conroy <james.conroy@arm.com> Change-Id: I5b7f2d2fb6e17e81573b41a31bc55f49ae79608f
Diffstat (limited to 'include')
-rw-r--r--include/armnn/Descriptors.hpp60
-rw-r--r--include/armnn/DescriptorsFwd.hpp1
-rw-r--r--include/armnn/ILayerSupport.hpp10
-rw-r--r--include/armnn/ILayerVisitor.hpp10
-rw-r--r--include/armnn/INetwork.hpp9
-rw-r--r--include/armnn/LayerVisitorBase.hpp5
6 files changed, 95 insertions, 0 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 57917261d4..95eeaaa420 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1083,6 +1083,66 @@ struct PreCompiledDescriptor
unsigned int m_NumOutputSlots;
};
+/// A QLstmDescriptor for the QLstmLayer. Defaults: CIFG on; peephole, projection and layer norm off; numerics zero.
+struct QLstmDescriptor
+{
+ QLstmDescriptor()
+ : m_CellClip(0.0)
+ , m_ProjectionClip(0.0)
+ , m_CifgEnabled(true)
+ , m_PeepholeEnabled(false)
+ , m_ProjectionEnabled(false)
+ , m_LayerNormEnabled(false)
+ , m_InputIntermediateScale(0.0)
+ , m_ForgetIntermediateScale(0.0)
+ , m_CellIntermediateScale(0.0)
+ , m_OutputIntermediateScale(0.0)
+ , m_HiddenStateZeroPoint(0)
+ , m_HiddenStateScale(0.0)
+ {}
+
+ bool operator ==(const QLstmDescriptor& rhs) const // Member-wise equality; float fields are compared exactly.
+ {
+ return m_CellClip == rhs.m_CellClip &&
+ m_ProjectionClip == rhs.m_ProjectionClip &&
+ m_CifgEnabled == rhs.m_CifgEnabled &&
+ m_PeepholeEnabled == rhs.m_PeepholeEnabled &&
+ m_ProjectionEnabled == rhs.m_ProjectionEnabled &&
+ m_LayerNormEnabled == rhs.m_LayerNormEnabled &&
+ m_InputIntermediateScale == rhs.m_InputIntermediateScale &&
+ m_ForgetIntermediateScale == rhs.m_ForgetIntermediateScale &&
+ m_CellIntermediateScale == rhs.m_CellIntermediateScale &&
+ m_OutputIntermediateScale == rhs.m_OutputIntermediateScale &&
+ m_HiddenStateZeroPoint == rhs.m_HiddenStateZeroPoint &&
+ m_HiddenStateScale == rhs.m_HiddenStateScale;
+ }
+
+ /// Clipping threshold value for the cell state. Defaults to 0.0.
+ float m_CellClip;
+ /// Clipping threshold value for the projection. Defaults to 0.0.
+ float m_ProjectionClip;
+ /// Enable/disable CIFG (coupled input & forget gate). Enabled by default.
+ bool m_CifgEnabled;
+ /// Enable/disable peephole. Disabled by default.
+ bool m_PeepholeEnabled;
+ /// Enable/disable the projection layer. Disabled by default.
+ bool m_ProjectionEnabled;
+ /// Enable/disable layer normalization. Disabled by default.
+ bool m_LayerNormEnabled;
+ /// Input intermediate quantization scale. Defaults to 0.0.
+ float m_InputIntermediateScale;
+ /// Forget intermediate quantization scale. Defaults to 0.0.
+ float m_ForgetIntermediateScale;
+ /// Cell intermediate quantization scale. Defaults to 0.0.
+ float m_CellIntermediateScale;
+ /// Output intermediate quantization scale. Defaults to 0.0.
+ float m_OutputIntermediateScale;
+ /// Hidden State zero point. Defaults to 0.
+ int32_t m_HiddenStateZeroPoint;
+ /// Hidden State quantization scale. Defaults to 0.0.
+ float m_HiddenStateScale;
+};
+
/// A TransposeConvolution2dDescriptor for the TransposeConvolution2dLayer.
struct TransposeConvolution2dDescriptor
{
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 1298c1ce01..f0903728dd 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -29,6 +29,7 @@ struct PadDescriptor;
struct PermuteDescriptor;
struct Pooling2dDescriptor;
struct PreCompiledDescriptor;
+struct QLstmDescriptor;
struct ReshapeDescriptor;
struct ResizeBilinearDescriptor;
struct ResizeDescriptor;
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 8274b05535..58509c906c 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -284,6 +284,16 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsQLstmSupported(const TensorInfo& input,
+ const TensorInfo& previousOutputIn,
+ const TensorInfo& previousCellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
+ const TensorInfo& output,
+ const QLstmDescriptor& descriptor,
+ const LstmInputParamsInfo& paramsInfo, // Existing LSTM params-info type, not a QLstm-specific one.
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
virtual bool IsQuantizedLstmSupported(const TensorInfo& input,
const TensorInfo& previousCellStateIn,
const TensorInfo& previousOutputIn,
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index 972915dc0f..530e74f30a 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -360,6 +360,16 @@ public:
virtual void VisitQuantizeLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
+ /// Function a QLstm layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - Pointer to the layer which is calling back to this visit function.
+ /// @param descriptor - Parameters controlling the operation of the QLstm layer.
+ /// @param params - The weights and biases for the layer, passed as the existing LstmInputParams type.
+ /// @param name - Optional name for the layer; defaults to nullptr.
+ virtual void VisitQLstmLayer(const IConnectableLayer* layer,
+ const QLstmDescriptor& descriptor,
+ const LstmInputParams& params,
+ const char* name = nullptr) = 0;
+
/// Function a QuantizedLstm layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param params - The weights and biases for the Quantized LSTM cell
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index c976b82c0b..84ecaebfb9 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -555,6 +555,15 @@ public:
virtual IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
const char* name = nullptr) = 0;
+ /// Add a QLstm layer to the network.
+ /// @param descriptor - Parameters for the QLstm operation.
+ /// @param params - Weights and biases for the layer, passed as the existing LstmInputParams type.
+ /// @param name - Optional name for the layer; defaults to nullptr.
+ /// @return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddQLstmLayer(const QLstmDescriptor& descriptor,
+ const LstmInputParams& params,
+ const char* name = nullptr) = 0;
+
virtual void Accept(ILayerVisitor& visitor) const = 0;
protected:
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index 9335ff8b0b..95d6bd37bd 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -183,6 +183,11 @@ public:
void VisitQuantizeLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(__func__); }
+ void VisitQLstmLayer(const IConnectableLayer*,
+ const QLstmDescriptor&,
+ const LstmInputParams&,
+ const char*) override { DefaultPolicy::Apply(__func__); } // Arguments are unused; only the function name is passed on.
+
void VisitQuantizedLstmLayer(const IConnectableLayer*,
const QuantizedLstmInputParams&,
const char*) override { DefaultPolicy::Apply(__func__); }