diff options
Diffstat (limited to 'src/armnn')
-rw-r--r-- | src/armnn/BackendHelper.cpp | 8 | ||||
-rw-r--r-- | src/armnn/ILayerSupport.cpp | 68 | ||||
-rw-r--r-- | src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp | 4 |
3 files changed, 18 insertions, 62 deletions
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp index 056fbb08fa..e2aa67275f 100644 --- a/src/armnn/BackendHelper.cpp +++ b/src/armnn/BackendHelper.cpp @@ -1332,16 +1332,14 @@ bool LayerSupportHandle::IsTransposeSupported(const TensorInfo& input, bool LayerSupportHandle::IsUnidirectionalSequenceLstmSupported(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, const TensorInfo& output, - const Optional<TensorInfo>& hiddenStateOutput, - const Optional<TensorInfo>& cellStateOutput, const LstmDescriptor& descriptor, const LstmInputParamsInfo& paramsInfo, Optional<std::string&> reasonIfUnsupported) { - TensorInfo hiddenStateOutputVal = hiddenStateOutput.has_value() ? hiddenStateOutput.value() : TensorInfo(); - TensorInfo cellStateOutputVal = cellStateOutput.has_value() ? cellStateOutput.value() : TensorInfo(); - TensorInfos infos{input, outputStateIn, cellStateIn, hiddenStateOutputVal, cellStateOutputVal, output}; + TensorInfos infos{input, outputStateIn, cellStateIn, outputStateOut, cellStateOut, output}; return m_LayerSupport->IsLayerSupported(LayerType::UnidirectionalSequenceLstm, infos, diff --git a/src/armnn/ILayerSupport.cpp b/src/armnn/ILayerSupport.cpp index bf54223414..5366b13088 100644 --- a/src/armnn/ILayerSupport.cpp +++ b/src/armnn/ILayerSupport.cpp @@ -488,57 +488,15 @@ bool ILayerSupport::IsLayerSupported(const LayerType& type, "hiddenStateOutputVal, cellStateOutputVal, output}"); } auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor)); - - bool isHiddenStateOutputOptional = (infos[4] == TensorInfo()); - bool isCellStateOutput = (infos[5] == TensorInfo()); - if (isHiddenStateOutputOptional && isCellStateOutput) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - EmptyOptional(), - EmptyOptional(), - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else if (isHiddenStateOutputOptional) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - EmptyOptional(), - infos[5], - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else if (isCellStateOutput) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - EmptyOptional(), - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - infos[5], - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } + return IsUnidirectionalSequenceLstmSupported(infos[0], + infos[1], + infos[2], + infos[3], + infos[4], + infos[5], + desc, + lstmParamsInfo.value(), + reasonIfUnsupported); } case LayerType::ChannelShuffle: return IsChannelShuffleSupported(infos[0], @@ -1285,9 +1243,9 @@ bool ILayerSupport::IsUnidirectionalSequenceLstmSupported( const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, const TensorInfo& output, - const Optional<TensorInfo>& hiddenStateOutput, - const Optional<TensorInfo>& cellStateOutput, const LstmDescriptor& descriptor, const LstmInputParamsInfo& paramsInfo, Optional<std::string&> reasonIfUnsupported) const @@ -1295,9 +1253,9 @@ bool ILayerSupport::IsUnidirectionalSequenceLstmSupported( IgnoreUnused(input, outputStateIn, cellStateIn, + outputStateOut, + cellStateOut, output, - hiddenStateOutput, - cellStateOutput, descriptor, paramsInfo, reasonIfUnsupported); diff --git a/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp b/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp index 199961449e..e5f89bd017 100644 --- a/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp +++ b/src/armnn/layers/UnidirectionalSequenceLstmLayer.cpp @@ -15,7 +15,7 @@ namespace armnn { UnidirectionalSequenceLstmLayer::UnidirectionalSequenceLstmLayer(const LstmDescriptor& param, const char* name) - : LayerWithParameters(3, 1, LayerType::UnidirectionalSequenceLstm, param, name) + : LayerWithParameters(3, 3, LayerType::UnidirectionalSequenceLstm, param, name) { } @@ -171,7 +171,7 @@ void UnidirectionalSequenceLstmLayer::ValidateTensorShapesFromInputs() { VerifyLayerConnections(3, CHECK_LOCATION()); - const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape(); + const TensorShape& outputShape = GetOutputSlot(2).GetTensorInfo().GetShape(); VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); |