diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-06-15 10:57:52 +0100 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2022-06-15 14:29:19 +0000 |
commit | c0800a32634db3e78d92493402492ff8426a91de (patch) | |
tree | 78204c1eb043902e572b000ee664cb32d329dc3a | |
parent | d21abaf5c9e899164484044e1e812739a779a6b8 (diff) | |
download | armnn-c0800a32634db3e78d92493402492ff8426a91de.tar.gz |
GitHub 653: Segfault when parsing Unidirectional Sequence LSTM
* Fixed Segfault when parsing Unidirectional Sequence LSTM
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: Ic69a4190c60ef595be64bc2c356e540319381b7e
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 49f1f9f856..479fc4f474 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -3346,7 +3346,7 @@ void TfLiteParserImpl::ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, siz || params.m_OutputLayerNormWeights != nullptr); desc.m_TimeMajor = nodeParams->time_major; - if (desc.m_LayerNormEnabled) + if (operatorPtr->intermediates.size() > 3 && desc.m_LayerNormEnabled) { auto inputIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[0]].get(), inputTensorInfo).first; @@ -3377,12 +3377,14 @@ void TfLiteParserImpl::ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, siz desc.m_OutputIntermediateScale = defaultIntermediate; } - auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(), - inputTensorInfo).first; - - desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale(); - desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset(); + if (operatorPtr->intermediates.size() > 4) + { + auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(), + inputTensorInfo).first; + desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale(); + desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset(); + } unsigned int batchSize = inputTensorInfo.GetShape()[0]; unsigned int outputSize = outputTensorInfo.GetShape()[2]; unsigned int numUnits = cellStateInInfo.GetShape()[1]; |