aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-07-15 16:16:25 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2021-07-22 18:29:55 +0100
commit8ed39ae450a077c7e4d672b5f05ff1d68ee67aab (patch)
tree31a1cf006e50db54f3e7a605825c8e9e3f9d689e /include
parent15fcc7ed3163c9d4b1856955271854198c3c2696 (diff)
downloadarmnn-8ed39ae450a077c7e4d672b5f05ff1d68ee67aab.tar.gz
MLCE-530 Add front end support for UnidirectionalSequenceLstm on ArmNN
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I57bcbdec3eb0155f41af0fe7d6abf9bac2ec86eb
Diffstat (limited to 'include')
-rw-r--r--include/armnn/BackendHelper.hpp11
-rw-r--r--include/armnn/Descriptors.hpp8
-rw-r--r--include/armnn/DescriptorsFwd.hpp1
-rw-r--r--include/armnn/INetwork.hpp9
-rw-r--r--include/armnn/Types.hpp9
-rw-r--r--include/armnn/backends/ILayerSupport.hpp11
6 files changed, 45 insertions, 4 deletions
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index 093f822040..dee3b48b81 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -433,6 +433,17 @@ public:
const TransposeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+ bool IsUnidirectionalSequenceLstmSupported(
+ const TensorInfo& input,
+ const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateIn,
+ const TensorInfo& output,
+ const Optional<TensorInfo>& hiddenStateOutput,
+ const Optional<TensorInfo>& cellStateOutput,
+ const LstmDescriptor& descriptor,
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+
private:
std::shared_ptr<ILayerSupport> m_LayerSupport;
const BackendId m_BackendId;
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 683ef7ac98..bcee902d75 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -926,6 +926,7 @@ struct LstmDescriptor : BaseDescriptor
, m_PeepholeEnabled(false)
, m_ProjectionEnabled(false)
, m_LayerNormEnabled(false)
+ , m_TimeMajor(true)
{}
bool operator ==(const LstmDescriptor& rhs) const
@@ -935,7 +936,8 @@ struct LstmDescriptor : BaseDescriptor
m_ClippingThresProj == rhs.m_ClippingThresProj &&
m_CifgEnabled == rhs.m_CifgEnabled &&
m_PeepholeEnabled == rhs.m_PeepholeEnabled &&
- m_LayerNormEnabled == rhs.m_LayerNormEnabled;
+ m_LayerNormEnabled == rhs.m_LayerNormEnabled &&
+ m_TimeMajor == rhs.m_TimeMajor;
}
/// @brief The activation function to use.
@@ -953,8 +955,12 @@ struct LstmDescriptor : BaseDescriptor
bool m_ProjectionEnabled;
/// Enable/disable layer normalization
bool m_LayerNormEnabled;
+ /// Enable/disable time major
+ bool m_TimeMajor;
};
+using UnidirectionalSequenceLstmDescriptor = LstmDescriptor;
+
/// A MeanDescriptor for the MeanLayer.
struct MeanDescriptor : BaseDescriptor
{
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 9b22644c7b..3b43c42d23 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -55,5 +55,6 @@ using LogSoftmaxDescriptor = SoftmaxDescriptor;
/// MergerDescriptor is deprecated, use ConcatDescriptor instead
using MergerDescriptor = OriginsDescriptor;
using SplitterDescriptor = ViewsDescriptor;
+using UnidirectionalSequenceLstmDescriptor = LstmDescriptor;
} // namespace armnn
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index b40db62a59..865d1291a9 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -691,6 +691,15 @@ public:
IConnectableLayer* AddLogicalBinaryLayer(const LogicalBinaryDescriptor& descriptor,
const char* name = nullptr);
+ /// Add a UnidirectionalSequenceLstm layer to the network
+ /// @param descriptor - Parameters for the UnidirectionalSequenceLstm operation
+ /// @param params - Weights and biases for the UnidirectionalSequenceLstm
+ /// @param name - Optional name for the layer
+ /// @return - Interface for configuring the layer.
+ IConnectableLayer* AddUnidirectionalSequenceLstmLayer(const UnidirectionalSequenceLstmDescriptor& descriptor,
+ const LstmInputParams& params,
+ const char* name = nullptr);
+
void Accept(ILayerVisitor& visitor) const;
void ExecuteStrategy(IStrategy& strategy) const;
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index e7c17608ca..056aa83d2f 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -333,7 +333,6 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
X(ArgMinMax) \
X(BatchNormalization) \
X(BatchToSpaceNd) \
- X(Cast) \
X(Comparison) \
X(Concat) \
X(Constant) \
@@ -382,7 +381,6 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
X(Rank) \
X(Resize) \
X(Reduce) \
- X(Shape) \
X(Slice) \
X(Softmax) \
X(SpaceToBatchNd) \
@@ -396,6 +394,11 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
X(Transpose) \
X(TransposeConvolution2d) \
X(Unmap) \
+ X(Cast) \
+ X(Shape) \
+ X(UnidirectionalSequenceLstm) \
+
+// New layers should be added at last to minimize instability.
/// When adding a new layer, adapt also the LastLayer enum value in the
/// enum class LayerType below
@@ -405,7 +408,7 @@ enum class LayerType
LIST_OF_LAYER_TYPE
#undef X
FirstLayer = Activation,
- LastLayer = Unmap
+ LastLayer = UnidirectionalSequenceLstm
};
const char* GetLayerTypeAsCString(LayerType type);
diff --git a/include/armnn/backends/ILayerSupport.hpp b/include/armnn/backends/ILayerSupport.hpp
index 462668d738..7ba565a138 100644
--- a/include/armnn/backends/ILayerSupport.hpp
+++ b/include/armnn/backends/ILayerSupport.hpp
@@ -424,6 +424,17 @@ public:
const TransposeDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsUnidirectionalSequenceLstmSupported(
+ const TensorInfo& input,
+ const TensorInfo& outputStateIn,
+ const TensorInfo& cellStateIn,
+ const TensorInfo& output,
+ const Optional<TensorInfo>& hiddenStateOutput,
+ const Optional<TensorInfo>& cellStateOutput,
+ const LstmDescriptor& descriptor,
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
}; // class ILayerSupport
using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>;