aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2022-06-15 10:57:52 +0100
committermike.kelly <mike.kelly@arm.com>2022-06-15 14:29:19 +0000
commitc0800a32634db3e78d92493402492ff8426a91de (patch)
tree78204c1eb043902e572b000ee664cb32d329dc3a
parentd21abaf5c9e899164484044e1e812739a779a6b8 (diff)
downloadarmnn-c0800a32634db3e78d92493402492ff8426a91de.tar.gz
GitHub 653: Segfault when parsing Unidirectional Sequence LSTM
* Fixed Segfault when parsing Unidirectional Sequence LSTM Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: Ic69a4190c60ef595be64bc2c356e540319381b7e
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp14
1 files changed, 8 insertions, 6 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 49f1f9f856..479fc4f474 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -3346,7 +3346,7 @@ void TfLiteParserImpl::ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, siz
|| params.m_OutputLayerNormWeights != nullptr);
desc.m_TimeMajor = nodeParams->time_major;
- if (desc.m_LayerNormEnabled)
+ if (operatorPtr->intermediates.size() > 3 && desc.m_LayerNormEnabled)
{
auto inputIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[0]].get(),
inputTensorInfo).first;
@@ -3377,12 +3377,14 @@ void TfLiteParserImpl::ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, siz
desc.m_OutputIntermediateScale = defaultIntermediate;
}
- auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(),
- inputTensorInfo).first;
-
- desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale();
- desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset();
+ if (operatorPtr->intermediates.size() > 4)
+ {
+ auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(),
+ inputTensorInfo).first;
+ desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale();
+ desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset();
+ }
unsigned int batchSize = inputTensorInfo.GetShape()[0];
unsigned int outputSize = outputTensorInfo.GetShape()[2];
unsigned int numUnits = cellStateInInfo.GetShape()[1];