diff options
Diffstat (limited to 'src/armnnTfLiteParser')
-rw-r--r-- | src/armnnTfLiteParser/TfLiteParser.cpp | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp index 49f1f9f856..479fc4f474 100644 --- a/src/armnnTfLiteParser/TfLiteParser.cpp +++ b/src/armnnTfLiteParser/TfLiteParser.cpp @@ -3346,7 +3346,7 @@ void TfLiteParserImpl::ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, siz || params.m_OutputLayerNormWeights != nullptr); desc.m_TimeMajor = nodeParams->time_major; - if (desc.m_LayerNormEnabled) + if (operatorPtr->intermediates.size() > 3 && desc.m_LayerNormEnabled) { auto inputIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[0]].get(), inputTensorInfo).first; @@ -3377,12 +3377,14 @@ void TfLiteParserImpl::ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, siz desc.m_OutputIntermediateScale = defaultIntermediate; } - auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(), - inputTensorInfo).first; - - desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale(); - desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset(); + if (operatorPtr->intermediates.size() > 4) + { + auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(), + inputTensorInfo).first; + desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale(); + desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset(); + } unsigned int batchSize = inputTensorInfo.GetShape()[0]; unsigned int outputSize = outputTensorInfo.GetShape()[2]; unsigned int numUnits = cellStateInInfo.GetShape()[1]; |