From c0800a32634db3e78d92493402492ff8426a91de Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Wed, 15 Jun 2022 10:57:52 +0100 Subject: GitHub 653: Segfault when parsing Unidirectional Sequence LSTM * Fixed Segfault when parsing Unidirectional Sequence LSTM Signed-off-by: Mike Kelly Change-Id: Ic69a4190c60ef595be64bc2c356e540319381b7e --- src/armnnTfLiteParser/TfLiteParser.cpp | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 49f1f9f856..479fc4f474 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -3346,7 +3346,7 @@ void TfLiteParserImpl::ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, siz || params.m_OutputLayerNormWeights != nullptr); desc.m_TimeMajor = nodeParams->time_major; - if (desc.m_LayerNormEnabled) + if (operatorPtr->intermediates.size() > 3 && desc.m_LayerNormEnabled) { auto inputIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[0]].get(), inputTensorInfo).first; @@ -3377,12 +3377,14 @@ void TfLiteParserImpl::ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, siz desc.m_OutputIntermediateScale = defaultIntermediate; } - auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(), - inputTensorInfo).first; - - desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale(); - desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset(); + if (operatorPtr->intermediates.size() > 4) + { + auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(), + inputTensorInfo).first; + desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale(); + desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset(); + } unsigned int batchSize = inputTensorInfo.GetShape()[0]; unsigned int outputSize = outputTensorInfo.GetShape()[2]; unsigned int numUnits = cellStateInInfo.GetShape()[1]; -- cgit v1.2.1