aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/ILayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/ILayerSupport.cpp')
-rw-r--r--src/armnn/ILayerSupport.cpp68
1 files changed, 13 insertions, 55 deletions
diff --git a/src/armnn/ILayerSupport.cpp b/src/armnn/ILayerSupport.cpp
index bf54223414..5366b13088 100644
--- a/src/armnn/ILayerSupport.cpp
+++ b/src/armnn/ILayerSupport.cpp
@@ -488,57 +488,15 @@ bool ILayerSupport::IsLayerSupported(const LayerType& type,
"hiddenStateOutputVal, cellStateOutputVal, output}");
}
auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor));
-
- bool isHiddenStateOutputOptional = (infos[4] == TensorInfo());
- bool isCellStateOutput = (infos[5] == TensorInfo());
- if (isHiddenStateOutputOptional && isCellStateOutput)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- EmptyOptional(),
- EmptyOptional(),
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else if (isHiddenStateOutputOptional)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- EmptyOptional(),
- infos[5],
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else if (isCellStateOutput)
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- infos[4],
- EmptyOptional(),
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
- else
- {
- return IsUnidirectionalSequenceLstmSupported(infos[0],
- infos[1],
- infos[2],
- infos[3],
- infos[4],
- infos[5],
- desc,
- lstmParamsInfo.value(),
- reasonIfUnsupported);
- }
+ return IsUnidirectionalSequenceLstmSupported(infos[0],
+ infos[1],
+ infos[2],
+ infos[3],
+ infos[4],
+ infos[5],
+ desc,
+ lstmParamsInfo.value(),
+ reasonIfUnsupported);
}
case LayerType::ChannelShuffle:
return IsChannelShuffleSupported(infos[0],
@@ -1285,9 +1243,9 @@ bool ILayerSupport::IsUnidirectionalSequenceLstmSupported(
const TensorInfo& input,
const TensorInfo& outputStateIn,
const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
const TensorInfo& output,
- const Optional<TensorInfo>& hiddenStateOutput,
- const Optional<TensorInfo>& cellStateOutput,
const LstmDescriptor& descriptor,
const LstmInputParamsInfo& paramsInfo,
Optional<std::string&> reasonIfUnsupported) const
@@ -1295,9 +1253,9 @@ bool ILayerSupport::IsUnidirectionalSequenceLstmSupported(
IgnoreUnused(input,
outputStateIn,
cellStateIn,
+ outputStateOut,
+ cellStateOut,
output,
- hiddenStateOutput,
- cellStateOutput,
descriptor,
paramsInfo,
reasonIfUnsupported);