diff options
Diffstat (limited to 'src/armnn/ILayerSupport.cpp')
-rw-r--r-- | src/armnn/ILayerSupport.cpp | 68 |
1 files changed, 13 insertions, 55 deletions
diff --git a/src/armnn/ILayerSupport.cpp b/src/armnn/ILayerSupport.cpp index bf54223414..5366b13088 100644 --- a/src/armnn/ILayerSupport.cpp +++ b/src/armnn/ILayerSupport.cpp @@ -488,57 +488,15 @@ bool ILayerSupport::IsLayerSupported(const LayerType& type, "hiddenStateOutputVal, cellStateOutputVal, output}"); } auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor)); - - bool isHiddenStateOutputOptional = (infos[4] == TensorInfo()); - bool isCellStateOutput = (infos[5] == TensorInfo()); - if (isHiddenStateOutputOptional && isCellStateOutput) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - EmptyOptional(), - EmptyOptional(), - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else if (isHiddenStateOutputOptional) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - EmptyOptional(), - infos[5], - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else if (isCellStateOutput) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - EmptyOptional(), - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - infos[5], - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } + return IsUnidirectionalSequenceLstmSupported(infos[0], + infos[1], + infos[2], + infos[3], + infos[4], + infos[5], + desc, + lstmParamsInfo.value(), + reasonIfUnsupported); } case LayerType::ChannelShuffle: return IsChannelShuffleSupported(infos[0], @@ -1285,9 +1243,9 @@ bool ILayerSupport::IsUnidirectionalSequenceLstmSupported( const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, const TensorInfo& output, - const Optional<TensorInfo>& hiddenStateOutput, - const Optional<TensorInfo>& cellStateOutput, const LstmDescriptor& descriptor, const LstmInputParamsInfo& paramsInfo, Optional<std::string&> reasonIfUnsupported) const @@ -1295,9 +1253,9 @@ bool ILayerSupport::IsUnidirectionalSequenceLstmSupported( IgnoreUnused(input, outputStateIn, cellStateIn, + outputStateOut, + cellStateOut, output, - hiddenStateOutput, - cellStateOutput, descriptor, paramsInfo, reasonIfUnsupported); |