From 1299496996bc332f02218f926640a9255ed60310 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Thu, 21 Apr 2022 11:57:09 +0100 Subject: IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon * Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported * outputStateOut TensorInfo is not optional. * cellStateOut TensorInfo is not optional. * TensorInfo Order matches other QLSTM/LSTM layers. * Added missing parameters to UnidirectionalSequenceLstmOperator for delegate. * Added quantized UnidirectionalSequenceLstm support to Neon !android-nn-driver:7457 Signed-off-by: Mike Kelly Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4 --- delegate/src/UnidirectionalSequenceLstm.hpp | 49 ++++++++++++++++++++++++----- 1 file changed, 41 insertions(+), 8 deletions(-) (limited to 'delegate') diff --git a/delegate/src/UnidirectionalSequenceLstm.hpp b/delegate/src/UnidirectionalSequenceLstm.hpp index a923874a74..bcf01cf2a9 100644 --- a/delegate/src/UnidirectionalSequenceLstm.hpp +++ b/delegate/src/UnidirectionalSequenceLstm.hpp @@ -151,6 +151,36 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, || params.m_OutputLayerNormWeights != nullptr); desc.m_TimeMajor = nodeParams->time_major; + if (tfLiteNode->intermediates->size > 3 && desc.m_LayerNormEnabled) + { + auto inputIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[0]]); + auto forgetIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[1]]); + auto cellIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[2]]); + auto outputIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[3]]); + + desc.m_InputIntermediateScale = inputIntermediateTensorInfo.GetQuantizationScale(); + desc.m_ForgetIntermediateScale = forgetIntermediateTensorInfo.GetQuantizationScale(); + desc.m_CellIntermediateScale = cellIntermediateTensorInfo.GetQuantizationScale(); + desc.m_OutputIntermediateScale = outputIntermediateTensorInfo.GetQuantizationScale(); + } + else + { + float defaultIntermediate = std::pow(2, -12); + desc.m_InputIntermediateScale = defaultIntermediate; + desc.m_ForgetIntermediateScale = defaultIntermediate; + desc.m_CellIntermediateScale = defaultIntermediate; + desc.m_OutputIntermediateScale = defaultIntermediate; + } + if (tfLiteNode->intermediates->size > 4) + { + auto hiddentensorInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->intermediates->data[4]]); + desc.m_HiddenStateScale = hiddentensorInfo.GetQuantizationScale(); + desc.m_HiddenStateZeroPoint = hiddentensorInfo.GetQuantizationOffset(); + } const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); @@ -167,7 +197,11 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, { scratchBufferTensorInfo = armnn::TensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset); } - armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset); + armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, + cellStateInInfo.GetDataType(), + cellStateInInfo.GetQuantizationScale(), + cellStateInInfo.GetQuantizationOffset()); + armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset); armnn::LstmInputParamsInfo paramsInfo; @@ -218,9 +252,6 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo()); } - // hiddenStateOutput and cellStateOutput do not present in TfLite UnidirectionalSequenceLstm - armnn::Optional optionalTensor; - bool isSupported = false; auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported) { @@ -232,9 +263,9 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, inputTensorInfo, outputStateInInfo, cellStateInInfo, + outputStateOutTensorInfo, + cellStateOutTensorInfo, outputInfo, - optionalTensor, - optionalTensor, desc, paramsInfo); }; @@ -248,7 +279,9 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, armnn::IConnectableLayer* layer = delegateData.m_Network->AddUnidirectionalSequenceLstmLayer(desc, params); ARMNN_ASSERT(layer != nullptr); - layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + layer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo); + layer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo); + layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo); // Connect the inputs // input_layer @@ -258,7 +291,7 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, //outputStateIn delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[19]]->Connect(layer->GetInputSlot(2)); - armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0); + armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(2); delegateData.m_OutputSlotForNode[static_cast(tfLiteNode->outputs->data[0])] = &outputSlot; return kTfLiteOk; } -- cgit v1.2.1