From ee18dc8d1725f472850ab0c398fd7cbc4b850891 Mon Sep 17 00:00:00 2001 From: James Conroy Date: Wed, 17 Jul 2019 11:27:46 +0100 Subject: IVGCVSW-3469 Add front end for Quantized LSTM layer * Added new layer QuantizedLstm (Android Q) * Made necessary changes to APIs * Added unit tests Change-Id: I3b9f16b0e7e49f51932cf204c87cb7118798123a Signed-off-by: James Conroy --- include/armnn/ArmNN.hpp | 1 + include/armnn/ILayerSupport.hpp | 11 +- include/armnn/ILayerVisitor.hpp | 8 ++ include/armnn/INetwork.hpp | 14 ++- include/armnn/LayerSupport.hpp | 12 ++ include/armnn/LayerVisitorBase.hpp | 4 + include/armnn/NetworkFwd.hpp | 1 + include/armnn/QuantizedLstmParams.hpp | 218 ++++++++++++++++++++++++++++++++++ 8 files changed, 265 insertions(+), 4 deletions(-) create mode 100644 include/armnn/QuantizedLstmParams.hpp (limited to 'include') diff --git a/include/armnn/ArmNN.hpp b/include/armnn/ArmNN.hpp index 884a3ca844..b18f14c8b7 100644 --- a/include/armnn/ArmNN.hpp +++ b/include/armnn/ArmNN.hpp @@ -11,6 +11,7 @@ #include "IRuntime.hpp" #include "LstmParams.hpp" #include "Optional.hpp" +#include "QuantizedLstmParams.hpp" #include "Tensor.hpp" #include "Types.hpp" #include "TypesUtils.hpp" diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index 4301f9a196..45360984ff 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -6,8 +6,9 @@ #include #include -#include #include +#include +#include #include #include @@ -228,6 +229,14 @@ public: const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsQuantizedLstmSupported(const TensorInfo& input, + const TensorInfo& previousCellStateIn, + const TensorInfo& previousOutputIn, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QuantizedLstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsReshapeSupported(const TensorInfo& input, const ReshapeDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const = 0; diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp index 6e5b5463ac..1ccbf98d95 100644 --- a/include/armnn/ILayerVisitor.hpp +++ b/include/armnn/ILayerVisitor.hpp @@ -302,6 +302,14 @@ public: virtual void VisitQuantizeLayer(const IConnectableLayer* layer, const char* name = nullptr) = 0; + /// Function a QuantizedLstm layer should call back to when its Accept(ILayerVisitor&) function is invoked. + /// @param layer - pointer to the layer which is calling back to this visit function. + /// @param params - The weights and biases for the Quantized LSTM cell + /// @param name - Optional name for the layer. + virtual void VisitQuantizedLstmLayer(const IConnectableLayer* layer, + const QuantizedLstmInputParams& params, + const char* name = nullptr) = 0; + /// Function a reshape layer should call back to when its Accept(ILayerVisitor&) function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. /// @param reshapeDescriptor - Parameters for the reshape operation. diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 9e88c9279d..a2ff0dc575 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -356,9 +356,10 @@ public: virtual IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) = 0; /// Add a Lstm layer to the network - /// @param descriptor Parameters for the Lstm operation - /// @param name Optional name for the layer - /// @return Interface for configuring the layer. + /// @param descriptor - Parameters for the Lstm operation + /// @param params - Weights and biases for the LSTM cell + /// @param name - Optional name for the layer + /// @return - Interface for configuring the layer. virtual IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor, const LstmInputParams& params, const char* name = nullptr) = 0; @@ -458,6 +459,13 @@ public: virtual IConnectableLayer* AddStackLayer(const StackDescriptor& descriptor, const char* name = nullptr) = 0; + /// Add a QuantizedLstm layer to the network + /// @param params - The weights and biases for the Quantized LSTM cell + /// @param name - Optional name for the layer + /// @return - Interface for configuring the layer. + virtual IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params, + const char* name = nullptr) = 0; + virtual void Accept(ILayerVisitor& visitor) const = 0; protected: diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp index 6a3f1774bd..2ec086b185 100644 --- a/include/armnn/LayerSupport.hpp +++ b/include/armnn/LayerSupport.hpp @@ -10,6 +10,7 @@ #include #include #include "LstmParams.hpp" +#include "QuantizedLstmParams.hpp" namespace armnn { @@ -290,6 +291,17 @@ bool IsPooling2dSupported(const BackendId& backend, char* reasonIfUnsupported = nullptr, size_t reasonIfUnsupportedMaxLength = 1024); +/// Deprecated in favor of IBackend and ILayerSupport interfaces +bool IsQuantizedLstmSupported(const BackendId& backend, + const TensorInfo& input, + const TensorInfo& previousCellStateIn, + const TensorInfo& previousOutputIn, + const TensorInfo& cellStateOut, + const TensorInfo& output, + const QuantizedLstmInputParamsInfo& paramsInfo, + char* reasonIfUnsupported = nullptr, + size_t reasonIfUnsupportedMaxLength = 1024); + /// Deprecated in favor of IBackend and ILayerSupport interfaces bool IsReshapeSupported(const BackendId& backend, const TensorInfo& input, diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp index f107e9fb68..8c5464c29e 100644 --- a/include/armnn/LayerVisitorBase.hpp +++ b/include/armnn/LayerVisitorBase.hpp @@ -157,6 +157,10 @@ public: void VisitQuantizeLayer(const IConnectableLayer*, const char*) override { DefaultPolicy::Apply(__func__); } + void VisitQuantizedLstmLayer(const IConnectableLayer*, + const QuantizedLstmInputParams&, + const char*) override { DefaultPolicy::Apply(__func__); } + void VisitReshapeLayer(const IConnectableLayer*, const ReshapeDescriptor&, const char*) override { DefaultPolicy::Apply(__func__); } diff --git a/include/armnn/NetworkFwd.hpp b/include/armnn/NetworkFwd.hpp index 97c5e6eda6..e94a2cccae 100644 --- a/include/armnn/NetworkFwd.hpp +++ b/include/armnn/NetworkFwd.hpp @@ -7,6 +7,7 @@ namespace armnn { struct LstmInputParams; +struct QuantizedLstmInputParams; class INetwork; class IOptimizedNetwork; class Graph; diff --git a/include/armnn/QuantizedLstmParams.hpp b/include/armnn/QuantizedLstmParams.hpp new file mode 100644 index 0000000000..b3033acc9a --- /dev/null +++ b/include/armnn/QuantizedLstmParams.hpp @@ -0,0 +1,218 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// +#pragma once + +#include "TensorFwd.hpp" +#include "Exceptions.hpp" + +namespace armnn +{ + +struct QuantizedLstmInputParams +{ + QuantizedLstmInputParams() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + { + } + + const ConstTensor* m_InputToInputWeights; + const ConstTensor* m_InputToForgetWeights; + const ConstTensor* m_InputToCellWeights; + const ConstTensor* m_InputToOutputWeights; + + const ConstTensor* m_RecurrentToInputWeights; + const ConstTensor* m_RecurrentToForgetWeights; + const ConstTensor* m_RecurrentToCellWeights; + const ConstTensor* m_RecurrentToOutputWeights; + + const ConstTensor* m_InputGateBias; + const ConstTensor* m_ForgetGateBias; + const ConstTensor* m_CellBias; + const ConstTensor* m_OutputGateBias; + + const ConstTensor& deref(const ConstTensor* tensorPtr) const + { + if (tensorPtr != nullptr) + { + const ConstTensor &temp = *tensorPtr; + return temp; + } + throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer"); + } + + const ConstTensor& get_InputToInputWeights() const + { + return deref(m_InputToInputWeights); + } + + const ConstTensor& get_InputToForgetWeights() const + { + return deref(m_InputToForgetWeights); + } + + const ConstTensor& get_InputToCellWeights() const + { + return deref(m_InputToCellWeights); + } + + const ConstTensor& get_InputToOutputWeights() const + { + return deref(m_InputToOutputWeights); + } + + const ConstTensor& get_RecurrentToInputWeights() const + { + return deref(m_RecurrentToInputWeights); + } + + const ConstTensor& get_RecurrentToForgetWeights() const + { + return deref(m_RecurrentToForgetWeights); + } + + const ConstTensor& get_RecurrentToCellWeights() const + { + return deref(m_RecurrentToCellWeights); + } + + const ConstTensor& get_RecurrentToOutputWeights() const + { + return deref(m_RecurrentToOutputWeights); + } + + const ConstTensor& get_InputGateBias() const + { + return deref(m_InputGateBias); + } + + const ConstTensor& get_ForgetGateBias() const + { + return deref(m_ForgetGateBias); + } + + const ConstTensor& get_CellBias() const + { + return deref(m_CellBias); + } + + const ConstTensor& get_OutputGateBias() const + { + return deref(m_OutputGateBias); + } +}; + +struct QuantizedLstmInputParamsInfo +{ + QuantizedLstmInputParamsInfo() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , m_InputToOutputWeights(nullptr) + + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + { + } + + const TensorInfo* m_InputToInputWeights; + const TensorInfo* m_InputToForgetWeights; + const TensorInfo* m_InputToCellWeights; + const TensorInfo* m_InputToOutputWeights; + + const TensorInfo* m_RecurrentToInputWeights; + const TensorInfo* m_RecurrentToForgetWeights; + const TensorInfo* m_RecurrentToCellWeights; + const TensorInfo* m_RecurrentToOutputWeights; + + const TensorInfo* m_InputGateBias; + const TensorInfo* m_ForgetGateBias; + const TensorInfo* m_CellBias; + const TensorInfo* m_OutputGateBias; + + + const TensorInfo& deref(const TensorInfo* tensorInfo) const + { + if (tensorInfo != nullptr) + { + const TensorInfo &temp = *tensorInfo; + return temp; + } + throw InvalidArgumentException("Can't dereference a null pointer"); + } + + const TensorInfo& get_InputToInputWeights() const + { + return deref(m_InputToInputWeights); + } + const TensorInfo& get_InputToForgetWeights() const + { + return deref(m_InputToForgetWeights); + } + const TensorInfo& get_InputToCellWeights() const + { + return deref(m_InputToCellWeights); + } + const TensorInfo& get_InputToOutputWeights() const + { + return deref(m_InputToOutputWeights); + } + + const TensorInfo& get_RecurrentToInputWeights() const + { + return deref(m_RecurrentToInputWeights); + } + const TensorInfo& get_RecurrentToForgetWeights() const + { + return deref(m_RecurrentToForgetWeights); + } + const TensorInfo& get_RecurrentToCellWeights() const + { + return deref(m_RecurrentToCellWeights); + } + const TensorInfo& get_RecurrentToOutputWeights() const + { + return deref(m_RecurrentToOutputWeights); + } + + const TensorInfo& get_InputGateBias() const + { + return deref(m_InputGateBias); + } + const TensorInfo& get_ForgetGateBias() const + { + return deref(m_ForgetGateBias); + } + const TensorInfo& get_CellBias() const + { + return deref(m_CellBias); + } + const TensorInfo& get_OutputGateBias() const + { + return deref(m_OutputGateBias); + } +}; + +} // namespace armnn + -- cgit v1.2.1