aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/BackendHelper.cpp8
-rw-r--r--src/armnn/ILayerSupport.cpp68
-rw-r--r--src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp4
3 files changed, 18 insertions, 62 deletions
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp
index 056fbb08fa..e2aa67275f 100644
--- a/src/armnn/BackendHelper.cpp
+++ b/src/armnn/BackendHelper.cpp
@@ -1332,16 +1332,14 @@ bool LayerSupportHandle::IsTransposeSupported(const TensorInfo& input,
bool LayerSupportHandle::IsUnidirectionalSequenceLstmSupported(const TensorInfo& input,
const TensorInfo& outputStateIn,
const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
const TensorInfo& output,
- const Optional<TensorInfo>& hiddenStateOutput,
- const Optional<TensorInfo>& cellStateOutput,
const LstmDescriptor& descriptor,
const LstmInputParamsInfo& paramsInfo,
Optional<std::string&> reasonIfUnsupported)
{
- TensorInfo hiddenStateOutputVal = hiddenStateOutput.has_value() ? hiddenStateOutput.value() : TensorInfo();
- TensorInfo cellStateOutputVal = cellStateOutput.has_value() ? cellStateOutput.value() : TensorInfo();
- TensorInfos infos{input, outputStateIn, cellStateIn, hiddenStateOutputVal, cellStateOutputVal, output};
+ TensorInfos infos{input, outputStateIn, cellStateIn, outputStateOut, cellStateOut, output};
return m_LayerSupport->IsLayerSupported(LayerType::UnidirectionalSequenceLstm,
infos,
diff --git a/src/armnn/ILayerSupport.cpp b/src/armnn/ILayerSupport.cpp
index bf54223414..5366b13088 100644
--- a/src/armnn/ILayerSupport.cpp
+++ b/src/armnn/ILayerSupport.cpp
@@ -488,57 +488,15 @@ bool ILayerSupport::IsLayerSupported(const LayerType& type,
"hiddenStateOutputVal, cellStateOutputVal, output}");
}
auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
-
- bool isHiddenStateOutputOptional = (infos[4] == TensorInfo());
- bool isCellStateOutput = (infos[5] == TensorInfo());
- if (isHiddenStateOutputOptional && isCellStateOutput)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- EmptyOptional(),
- EmptyOptional(),
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else if (isHiddenStateOutputOptional)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- EmptyOptional(),
- infos[5],
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else if (isCellStateOutput)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- infos[4],
- EmptyOptional(),
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- infos[4],
- infos[5],
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
+ return IsUnidirectionalSequenceLstmSupported(infos[0],
+ infos[1],
+ infos[2],
+ infos[3],
+ infos[4],
+ infos[5],
+ desc,
+ lstmParamsInfo.value(),
+ reasonIfUnsupported);
}
case LayerType::ChannelShuffle:
return IsChannelShuffleSupported(infos[0],
@@ -1285,9 +1243,9 @@ bool ILayerSupport::IsUnidirectionalSequenceLstmSupported(
const TensorInfo& input,
const TensorInfo& outputStateIn,
const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
const TensorInfo& output,
- const Optional<TensorInfo>& hiddenStateOutput,
- const Optional<TensorInfo>& cellStateOutput,
const LstmDescriptor& descriptor,
const LstmInputParamsInfo& paramsInfo,
Optional<std::string&> reasonIfUnsupported) const
@@ -1295,9 +1253,9 @@ bool ILayerSupport::IsUnidirectionalSequenceLstmSupported(
IgnoreUnused(input,
outputStateIn,
cellStateIn,
+ outputStateOut,
+ cellStateOut,
output,
- hiddenStateOutput,
- cellStateOutput,
descriptor,
paramsInfo,
reasonIfUnsupported);
diff --git a/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp b/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp
index 199961449e..e5f89bd017 100644
--- a/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp
+++ b/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp
@@ -15,7 +15,7 @@ namespace armnn
{
UnidirectionalSequenceLstmLayer::UnidirectionalSequenceLstmLayer(const LstmDescriptor& param, const char* name)
- : LayerWithParameters(3, 1, LayerType::UnidirectionalSequenceLstm, param, name)
+ : LayerWithParameters(3, 3, LayerType::UnidirectionalSequenceLstm, param, name)
{
}
@@ -171,7 +171,7 @@ void UnidirectionalSequenceLstmLayer::ValidateTensorShapesFromInputs()
{
VerifyLayerConnections(3, CHECK_LOCATION());
- const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
+ const TensorShape& outputShape = GetOutputSlot(2).GetTensorInfo().GetShape();
VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);