author    James Conroy <james.conroy@arm.com>  2019-07-17 11:27:46 +0100
committer Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>  2019-07-24 10:40:13 +0100
commit    ee18dc8d1725f472850ab0c398fd7cbc4b850891 (patch)
tree      b57738b18781d512f5438ca5154652571393e4e8 /include
parent    7b1845206d723a91aec811edaf7cb0cf832dfd25 (diff)
download  armnn-ee18dc8d1725f472850ab0c398fd7cbc4b850891.tar.gz
IVGCVSW-3469 Add front end for Quantized LSTM layer
* Added new layer QuantizedLstm (Android Q)
* Made necessary changes to APIs
* Added unit tests

Change-Id: I3b9f16b0e7e49f51932cf204c87cb7118798123a
Signed-off-by: James Conroy <james.conroy@arm.com>
Diffstat (limited to 'include')
-rw-r--r--  include/armnn/ArmNN.hpp                  1
-rw-r--r--  include/armnn/ILayerSupport.hpp         11
-rw-r--r--  include/armnn/ILayerVisitor.hpp          8
-rw-r--r--  include/armnn/INetwork.hpp              14
-rw-r--r--  include/armnn/LayerSupport.hpp          12
-rw-r--r--  include/armnn/LayerVisitorBase.hpp       4
-rw-r--r--  include/armnn/NetworkFwd.hpp             1
-rw-r--r--  include/armnn/QuantizedLstmParams.hpp  218
8 files changed, 265 insertions, 4 deletions
diff --git a/include/armnn/ArmNN.hpp b/include/armnn/ArmNN.hpp
index 884a3ca844..b18f14c8b7 100644
--- a/include/armnn/ArmNN.hpp
+++ b/include/armnn/ArmNN.hpp
@@ -11,6 +11,7 @@
#include "IRuntime.hpp"
#include "LstmParams.hpp"
#include "Optional.hpp"
+#include "QuantizedLstmParams.hpp"
#include "Tensor.hpp"
#include "Types.hpp"
#include "TypesUtils.hpp"
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 4301f9a196..45360984ff 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -6,8 +6,9 @@
#include <armnn/Deprecated.hpp>
#include <armnn/DescriptorsFwd.hpp>
-#include <armnn/Optional.hpp>
#include <armnn/LstmParams.hpp>
+#include <armnn/Optional.hpp>
+#include <armnn/QuantizedLstmParams.hpp>
#include <cctype>
#include <functional>
@@ -228,6 +229,14 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsQuantizedLstmSupported(const TensorInfo& input,
+ const TensorInfo& previousCellStateIn,
+ const TensorInfo& previousOutputIn,
+ const TensorInfo& cellStateOut,
+ const TensorInfo& output,
+ const QuantizedLstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
virtual bool IsReshapeSupported(const TensorInfo& input,
const ReshapeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
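The new ILayerSupport entry point lets each backend report whether it can execute a quantized LSTM with the given tensors. As a rough illustration, the standalone helper below sketches the kind of validation such an override might perform. The helper itself is hypothetical, and the data-type constraints (8-bit asymmetric activations, 16-bit symmetric cell state) are assumptions based on the Android Q NNAPI quantized LSTM, not something this commit mandates.

#include <armnn/Optional.hpp>
#include <armnn/QuantizedLstmParams.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include <string>

// Hypothetical, standalone sketch of the checks a backend's
// IsQuantizedLstmSupported override might perform.
bool CheckQuantizedLstmSupported(const armnn::TensorInfo& input,
                                 const armnn::TensorInfo& previousCellStateIn,
                                 const armnn::TensorInfo& previousOutputIn,
                                 const armnn::TensorInfo& cellStateOut,
                                 const armnn::TensorInfo& output,
                                 const armnn::QuantizedLstmInputParamsInfo& paramsInfo,
                                 armnn::Optional<std::string&> reasonIfUnsupported)
{
    using armnn::DataType;
    auto fail = [&](const char* reason)
    {
        if (reasonIfUnsupported.has_value()) { reasonIfUnsupported.value() = reason; }
        return false;
    };

    // Activations: 8-bit asymmetric quantization (assumption per NNAPI).
    if (input.GetDataType()            != DataType::QuantisedAsymm8 ||
        previousOutputIn.GetDataType() != DataType::QuantisedAsymm8 ||
        output.GetDataType()           != DataType::QuantisedAsymm8)
    {
        return fail("Activations must be QuantisedAsymm8");
    }
    // Cell state: 16-bit symmetric quantization (assumption per NNAPI).
    if (previousCellStateIn.GetDataType() != DataType::QuantisedSymm16 ||
        cellStateOut.GetDataType()        != DataType::QuantisedSymm16)
    {
        return fail("Cell state must be QuantisedSymm16");
    }
    // Spot-check a couple of the twelve mandatory tensors.
    if (paramsInfo.m_InputToInputWeights == nullptr || paramsInfo.m_InputGateBias == nullptr)
    {
        return fail("Missing mandatory weight or bias tensor");
    }
    return true;
}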
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index 6e5b5463ac..1ccbf98d95 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -302,6 +302,14 @@ public:
virtual void VisitQuantizeLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
+ /// Function a QuantizedLstm layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - pointer to the layer which is calling back to this visit function.
+ /// @param params - The weights and biases for the Quantized LSTM cell
+ /// @param name - Optional name for the layer.
+ virtual void VisitQuantizedLstmLayer(const IConnectableLayer* layer,
+ const QuantizedLstmInputParams& params,
+ const char* name = nullptr) = 0;
+
/// Function a reshape layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param reshapeDescriptor - Parameters for the reshape operation.
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 9e88c9279d..a2ff0dc575 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -356,9 +356,10 @@ public:
virtual IConnectableLayer* AddOutputLayer(LayerBindingId id, const char* name = nullptr) = 0;
/// Add a Lstm layer to the network
- /// @param descriptor Parameters for the Lstm operation
- /// @param name Optional name for the layer
- /// @return Interface for configuring the layer.
+ /// @param descriptor - Parameters for the Lstm operation
+ /// @param params - Weights and biases for the LSTM cell
+ /// @param name - Optional name for the layer
+ /// @return - Interface for configuring the layer.
virtual IConnectableLayer* AddLstmLayer(const LstmDescriptor& descriptor,
const LstmInputParams& params,
const char* name = nullptr) = 0;
@@ -458,6 +459,13 @@ public:
virtual IConnectableLayer* AddStackLayer(const StackDescriptor& descriptor,
const char* name = nullptr) = 0;
+ /// Add a QuantizedLstm layer to the network
+ /// @param params - The weights and biases for the Quantized LSTM cell
+ /// @param name - Optional name for the layer
+ /// @return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddQuantizedLstmLayer(const QuantizedLstmInputParams& params,
+ const char* name = nullptr) = 0;
+
virtual void Accept(ILayerVisitor& visitor) const = 0;
protected:
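A minimal sketch of wiring the new layer into a network follows. Only AddQuantizedLstmLayer and QuantizedLstmInputParams come from this commit; the shapes, quantization values, data types, and the slot ordering (inputs: input, cellStateIn, outputStateIn) are illustrative assumptions.

#include <armnn/ArmNN.hpp>
#include <vector>

int main()
{
    using namespace armnn;
    const unsigned int inputSize = 2, outputSize = 4;

    INetworkPtr network = INetwork::Create();
    IConnectableLayer* input         = network->AddInputLayer(0, "input");
    IConnectableLayer* cellStateIn   = network->AddInputLayer(1, "cellStateIn");
    IConnectableLayer* outputStateIn = network->AddInputLayer(2, "outputStateIn");

    // Weight and bias tensors; 8-bit asymmetric weights and 32-bit biases
    // are assumed here, following the Android Q quantized LSTM.
    TensorInfo inputWeightsInfo({outputSize, inputSize}, DataType::QuantisedAsymm8, 0.01f, 0);
    TensorInfo recurrentWeightsInfo({outputSize, outputSize}, DataType::QuantisedAsymm8, 0.01f, 0);
    TensorInfo biasInfo({outputSize}, DataType::Signed32, 0.0001f, 0);

    std::vector<uint8_t> inputWeightsData(outputSize * inputSize, 1);
    std::vector<uint8_t> recurrentWeightsData(outputSize * outputSize, 1);
    std::vector<int32_t> biasData(outputSize, 0);

    ConstTensor inputWeights(inputWeightsInfo, inputWeightsData);
    ConstTensor recurrentWeights(recurrentWeightsInfo, recurrentWeightsData);
    ConstTensor bias(biasInfo, biasData);

    // For brevity the same tensors are reused across all four gates.
    QuantizedLstmInputParams params;
    params.m_InputToInputWeights      = &inputWeights;
    params.m_InputToForgetWeights     = &inputWeights;
    params.m_InputToCellWeights       = &inputWeights;
    params.m_InputToOutputWeights     = &inputWeights;
    params.m_RecurrentToInputWeights  = &recurrentWeights;
    params.m_RecurrentToForgetWeights = &recurrentWeights;
    params.m_RecurrentToCellWeights   = &recurrentWeights;
    params.m_RecurrentToOutputWeights = &recurrentWeights;
    params.m_InputGateBias            = &bias;
    params.m_ForgetGateBias           = &bias;
    params.m_CellBias                 = &bias;
    params.m_OutputGateBias           = &bias;

    IConnectableLayer* lstm = network->AddQuantizedLstmLayer(params, "quantizedLstm");

    input->GetOutputSlot(0).Connect(lstm->GetInputSlot(0));
    cellStateIn->GetOutputSlot(0).Connect(lstm->GetInputSlot(1));
    outputStateIn->GetOutputSlot(0).Connect(lstm->GetInputSlot(2));
    return 0;
}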
diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp
index 6a3f1774bd..2ec086b185 100644
--- a/include/armnn/LayerSupport.hpp
+++ b/include/armnn/LayerSupport.hpp
@@ -10,6 +10,7 @@
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>
#include "LstmParams.hpp"
+#include "QuantizedLstmParams.hpp"
namespace armnn
{
@@ -291,6 +292,17 @@ bool IsPooling2dSupported(const BackendId& backend,
size_t reasonIfUnsupportedMaxLength = 1024);
/// Deprecated in favor of IBackend and ILayerSupport interfaces
+bool IsQuantizedLstmSupported(const BackendId& backend,
+ const TensorInfo& input,
+ const TensorInfo& previousCellStateIn,
+ const TensorInfo& previousOutputIn,
+ const TensorInfo& cellStateOut,
+ const TensorInfo& output,
+ const QuantizedLstmInputParamsInfo& paramsInfo,
+ char* reasonIfUnsupported = nullptr,
+ size_t reasonIfUnsupportedMaxLength = 1024);
+
+/// Deprecated in favor of IBackend and ILayerSupport interfaces
bool IsReshapeSupported(const BackendId& backend,
const TensorInfo& input,
const ReshapeDescriptor& descriptor,
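The deprecated free function can be called ahead of network construction to probe a backend. A sketch, assuming the reference backend id "CpuRef" and placeholder shapes and quantization values; the QuantisedAsymm8/QuantisedSymm16/Signed32 type choices again follow the Android Q quantized LSTM rather than anything this commit enforces.

#include <armnn/LayerSupport.hpp>
#include <iostream>

int main()
{
    using namespace armnn;
    const unsigned int numBatches = 1, inputSize = 2, outputSize = 4;

    TensorInfo input({numBatches, inputSize}, DataType::QuantisedAsymm8, 0.0078125f, 128);
    TensorInfo cellState({numBatches, outputSize}, DataType::QuantisedSymm16, 0.00048828125f, 0);
    TensorInfo outputState({numBatches, outputSize}, DataType::QuantisedAsymm8, 0.0078125f, 128);

    TensorInfo inputWeights({outputSize, inputSize}, DataType::QuantisedAsymm8, 0.01f, 0);
    TensorInfo recurrentWeights({outputSize, outputSize}, DataType::QuantisedAsymm8, 0.01f, 0);
    TensorInfo bias({outputSize}, DataType::Signed32, 0.0001f, 0);

    // All twelve tensor infos must be set; reuse one per kind for brevity.
    QuantizedLstmInputParamsInfo paramsInfo;
    paramsInfo.m_InputToInputWeights     = paramsInfo.m_InputToForgetWeights =
    paramsInfo.m_InputToCellWeights      = paramsInfo.m_InputToOutputWeights = &inputWeights;
    paramsInfo.m_RecurrentToInputWeights = paramsInfo.m_RecurrentToForgetWeights =
    paramsInfo.m_RecurrentToCellWeights  = paramsInfo.m_RecurrentToOutputWeights = &recurrentWeights;
    paramsInfo.m_InputGateBias           = paramsInfo.m_ForgetGateBias =
    paramsInfo.m_CellBias                = paramsInfo.m_OutputGateBias = &bias;

    char reason[1024] = {0};
    const bool supported = IsQuantizedLstmSupported(BackendId("CpuRef"),
                                                    input,
                                                    cellState,    // previousCellStateIn
                                                    outputState,  // previousOutputIn
                                                    cellState,    // cellStateOut
                                                    outputState,  // output
                                                    paramsInfo,
                                                    reason,
                                                    sizeof(reason));
    std::cout << (supported ? "supported" : reason) << std::endl;
    return 0;
}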
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index f107e9fb68..8c5464c29e 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -157,6 +157,10 @@ public:
void VisitQuantizeLayer(const IConnectableLayer*,
const char*) override { DefaultPolicy::Apply(__func__); }
+ void VisitQuantizedLstmLayer(const IConnectableLayer*,
+ const QuantizedLstmInputParams&,
+ const char*) override { DefaultPolicy::Apply(__func__); }
+
void VisitReshapeLayer(const IConnectableLayer*,
const ReshapeDescriptor&,
const char*) override { DefaultPolicy::Apply(__func__); }
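Because LayerVisitorBase supplies a default for every visit function, a client only overrides the callbacks it cares about. The inspector below is a hypothetical sketch; it assumes the no-throw default policy (VisitorNoThrowPolicy) defined alongside LayerVisitorBase, which turns every unimplemented visit into a no-op.

#include <armnn/LayerVisitorBase.hpp>
#include <armnn/QuantizedLstmParams.hpp>
#include <armnn/Tensor.hpp>
#include <iostream>

class QuantizedLstmInspector : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy>
{
public:
    void VisitQuantizedLstmLayer(const armnn::IConnectableLayer* /*layer*/,
                                 const armnn::QuantizedLstmInputParams& params,
                                 const char* name) override
    {
        std::cout << "Found QuantizedLstm layer: " << (name ? name : "<unnamed>") << std::endl;
        // The get_* accessors throw InvalidArgumentException if the tensor is unset.
        std::cout << "  input-to-input weights: "
                  << params.get_InputToInputWeights().GetNumElements() << " elements" << std::endl;
    }
};

// Usage sketch: construct an inspector and hand it to the network, which
// dispatches each layer to the matching Visit function:
//     QuantizedLstmInspector inspector;
//     network->Accept(inspector);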
diff --git a/include/armnn/NetworkFwd.hpp b/include/armnn/NetworkFwd.hpp
index 97c5e6eda6..e94a2cccae 100644
--- a/include/armnn/NetworkFwd.hpp
+++ b/include/armnn/NetworkFwd.hpp
@@ -7,6 +7,7 @@
namespace armnn
{
struct LstmInputParams;
+struct QuantizedLstmInputParams;
class INetwork;
class IOptimizedNetwork;
class Graph;
diff --git a/include/armnn/QuantizedLstmParams.hpp b/include/armnn/QuantizedLstmParams.hpp
new file mode 100644
index 0000000000..b3033acc9a
--- /dev/null
+++ b/include/armnn/QuantizedLstmParams.hpp
@@ -0,0 +1,218 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "TensorFwd.hpp"
+#include "Exceptions.hpp"
+
+namespace armnn
+{
+
+struct QuantizedLstmInputParams
+{
+ QuantizedLstmInputParams()
+ : m_InputToInputWeights(nullptr)
+ , m_InputToForgetWeights(nullptr)
+ , m_InputToCellWeights(nullptr)
+ , m_InputToOutputWeights(nullptr)
+
+ , m_RecurrentToInputWeights(nullptr)
+ , m_RecurrentToForgetWeights(nullptr)
+ , m_RecurrentToCellWeights(nullptr)
+ , m_RecurrentToOutputWeights(nullptr)
+
+ , m_InputGateBias(nullptr)
+ , m_ForgetGateBias(nullptr)
+ , m_CellBias(nullptr)
+ , m_OutputGateBias(nullptr)
+ {
+ }
+
+ const ConstTensor* m_InputToInputWeights;
+ const ConstTensor* m_InputToForgetWeights;
+ const ConstTensor* m_InputToCellWeights;
+ const ConstTensor* m_InputToOutputWeights;
+
+ const ConstTensor* m_RecurrentToInputWeights;
+ const ConstTensor* m_RecurrentToForgetWeights;
+ const ConstTensor* m_RecurrentToCellWeights;
+ const ConstTensor* m_RecurrentToOutputWeights;
+
+ const ConstTensor* m_InputGateBias;
+ const ConstTensor* m_ForgetGateBias;
+ const ConstTensor* m_CellBias;
+ const ConstTensor* m_OutputGateBias;
+
+ const ConstTensor& deref(const ConstTensor* tensorPtr) const
+ {
+ if (tensorPtr != nullptr)
+ {
+ const ConstTensor &temp = *tensorPtr;
+ return temp;
+ }
+ throw InvalidArgumentException("QuantizedLstmInputParams: Can't dereference a null pointer");
+ }
+
+ const ConstTensor& get_InputToInputWeights() const
+ {
+ return deref(m_InputToInputWeights);
+ }
+
+ const ConstTensor& get_InputToForgetWeights() const
+ {
+ return deref(m_InputToForgetWeights);
+ }
+
+ const ConstTensor& get_InputToCellWeights() const
+ {
+ return deref(m_InputToCellWeights);
+ }
+
+ const ConstTensor& get_InputToOutputWeights() const
+ {
+ return deref(m_InputToOutputWeights);
+ }
+
+ const ConstTensor& get_RecurrentToInputWeights() const
+ {
+ return deref(m_RecurrentToInputWeights);
+ }
+
+ const ConstTensor& get_RecurrentToForgetWeights() const
+ {
+ return deref(m_RecurrentToForgetWeights);
+ }
+
+ const ConstTensor& get_RecurrentToCellWeights() const
+ {
+ return deref(m_RecurrentToCellWeights);
+ }
+
+ const ConstTensor& get_RecurrentToOutputWeights() const
+ {
+ return deref(m_RecurrentToOutputWeights);
+ }
+
+ const ConstTensor& get_InputGateBias() const
+ {
+ return deref(m_InputGateBias);
+ }
+
+ const ConstTensor& get_ForgetGateBias() const
+ {
+ return deref(m_ForgetGateBias);
+ }
+
+ const ConstTensor& get_CellBias() const
+ {
+ return deref(m_CellBias);
+ }
+
+ const ConstTensor& get_OutputGateBias() const
+ {
+ return deref(m_OutputGateBias);
+ }
+};
+
+struct QuantizedLstmInputParamsInfo
+{
+ QuantizedLstmInputParamsInfo()
+ : m_InputToInputWeights(nullptr)
+ , m_InputToForgetWeights(nullptr)
+ , m_InputToCellWeights(nullptr)
+ , m_InputToOutputWeights(nullptr)
+
+ , m_RecurrentToInputWeights(nullptr)
+ , m_RecurrentToForgetWeights(nullptr)
+ , m_RecurrentToCellWeights(nullptr)
+ , m_RecurrentToOutputWeights(nullptr)
+
+ , m_InputGateBias(nullptr)
+ , m_ForgetGateBias(nullptr)
+ , m_CellBias(nullptr)
+ , m_OutputGateBias(nullptr)
+ {
+ }
+
+ const TensorInfo* m_InputToInputWeights;
+ const TensorInfo* m_InputToForgetWeights;
+ const TensorInfo* m_InputToCellWeights;
+ const TensorInfo* m_InputToOutputWeights;
+
+ const TensorInfo* m_RecurrentToInputWeights;
+ const TensorInfo* m_RecurrentToForgetWeights;
+ const TensorInfo* m_RecurrentToCellWeights;
+ const TensorInfo* m_RecurrentToOutputWeights;
+
+ const TensorInfo* m_InputGateBias;
+ const TensorInfo* m_ForgetGateBias;
+ const TensorInfo* m_CellBias;
+ const TensorInfo* m_OutputGateBias;
+
+
+ const TensorInfo& deref(const TensorInfo* tensorInfo) const
+ {
+ if (tensorInfo != nullptr)
+ {
+ const TensorInfo &temp = *tensorInfo;
+ return temp;
+ }
+ throw InvalidArgumentException("Can't dereference a null pointer");
+ }
+
+ const TensorInfo& get_InputToInputWeights() const
+ {
+ return deref(m_InputToInputWeights);
+ }
+ const TensorInfo& get_InputToForgetWeights() const
+ {
+ return deref(m_InputToForgetWeights);
+ }
+ const TensorInfo& get_InputToCellWeights() const
+ {
+ return deref(m_InputToCellWeights);
+ }
+ const TensorInfo& get_InputToOutputWeights() const
+ {
+ return deref(m_InputToOutputWeights);
+ }
+
+ const TensorInfo& get_RecurrentToInputWeights() const
+ {
+ return deref(m_RecurrentToInputWeights);
+ }
+ const TensorInfo& get_RecurrentToForgetWeights() const
+ {
+ return deref(m_RecurrentToForgetWeights);
+ }
+ const TensorInfo& get_RecurrentToCellWeights() const
+ {
+ return deref(m_RecurrentToCellWeights);
+ }
+ const TensorInfo& get_RecurrentToOutputWeights() const
+ {
+ return deref(m_RecurrentToOutputWeights);
+ }
+
+ const TensorInfo& get_InputGateBias() const
+ {
+ return deref(m_InputGateBias);
+ }
+ const TensorInfo& get_ForgetGateBias() const
+ {
+ return deref(m_ForgetGateBias);
+ }
+ const TensorInfo& get_CellBias() const
+ {
+ return deref(m_CellBias);
+ }
+ const TensorInfo& get_OutputGateBias() const
+ {
+ return deref(m_OutputGateBias);
+ }
+};
+
+} // namespace armnn
+
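The Info struct mirrors the tensor struct field for field, so one can be derived from the other when preparing a layer-support query. The helper below is hypothetical (not part of this commit); it relies on ConstTensor::GetInfo() and on the throwing get_* accessors, so it requires all twelve tensors to be set.

#include <armnn/QuantizedLstmParams.hpp>
#include <armnn/Tensor.hpp>

// Hypothetical helper: build a QuantizedLstmInputParamsInfo from a fully
// populated QuantizedLstmInputParams, e.g. for IsQuantizedLstmSupported.
// The returned pointers stay valid only while the source tensors live.
armnn::QuantizedLstmInputParamsInfo
MakeParamsInfo(const armnn::QuantizedLstmInputParams& params)
{
    armnn::QuantizedLstmInputParamsInfo info;
    info.m_InputToInputWeights      = &params.get_InputToInputWeights().GetInfo();
    info.m_InputToForgetWeights     = &params.get_InputToForgetWeights().GetInfo();
    info.m_InputToCellWeights       = &params.get_InputToCellWeights().GetInfo();
    info.m_InputToOutputWeights     = &params.get_InputToOutputWeights().GetInfo();
    info.m_RecurrentToInputWeights  = &params.get_RecurrentToInputWeights().GetInfo();
    info.m_RecurrentToForgetWeights = &params.get_RecurrentToForgetWeights().GetInfo();
    info.m_RecurrentToCellWeights   = &params.get_RecurrentToCellWeights().GetInfo();
    info.m_RecurrentToOutputWeights = &params.get_RecurrentToOutputWeights().GetInfo();
    info.m_InputGateBias            = &params.get_InputGateBias().GetInfo();
    info.m_ForgetGateBias           = &params.get_ForgetGateBias().GetInfo();
    info.m_CellBias                 = &params.get_CellBias().GetInfo();
    info.m_OutputGateBias           = &params.get_OutputGateBias().GetInfo();
    return info;
}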