diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-04-21 11:57:09 +0100 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2022-05-05 08:29:20 +0000 |
commit | 1299496996bc332f02218f926640a9255ed60310 (patch) | |
tree | 2d242e142bd8fe7387140bcf8cdf39cd13ffc9eb /delegate | |
parent | 8272a7bda2974c39b6c06e3eb3a000f2bdb749f7 (diff) | |
download | armnn-1299496996bc332f02218f926640a9255ed60310.tar.gz |
IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon
* Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported
* outputStateOut TensorInfo is not optional.
* cellStateOut TensorInfo is not optional.
* TensorInfo Order matches other QLSTM/LSTM layers.
* Added missing parameters to UnidirectionalSequenceLstmOperator for
the delegate.
* Added quantized UnidirectionalSequenceLstm support to Neon
!android-nn-driver:7457
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4
Diffstat (limited to 'delegate')
-rw-r--r-- | delegate/src/UnidirectionalSequenceLstm.hpp | 49 |
1 file changed, 41 insertions, 8 deletions
diff --git a/delegate/src/UnidirectionalSequenceLstm.hpp b/delegate/src/UnidirectionalSequenceLstm.hpp index a923874a74..bcf01cf2a9 100644 --- a/delegate/src/UnidirectionalSequenceLstm.hpp +++ b/delegate/src/UnidirectionalSequenceLstm.hpp @@ -151,6 +151,36 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, || params.m_OutputLayerNormWeights != nullptr); desc.m_TimeMajor = nodeParams->time_major; + if (tfLiteNode->intermediates->size > 3 && desc.m_LayerNormEnabled) + { + auto inputIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[0]]); + auto forgetIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[1]]); + auto cellIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[2]]); + auto outputIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[3]]); + + desc.m_InputIntermediateScale = inputIntermediateTensorInfo.GetQuantizationScale(); + desc.m_ForgetIntermediateScale = forgetIntermediateTensorInfo.GetQuantizationScale(); + desc.m_CellIntermediateScale = cellIntermediateTensorInfo.GetQuantizationScale(); + desc.m_OutputIntermediateScale = outputIntermediateTensorInfo.GetQuantizationScale(); + } + else + { + float defaultIntermediate = std::pow(2, -12); + desc.m_InputIntermediateScale = defaultIntermediate; + desc.m_ForgetIntermediateScale = defaultIntermediate; + desc.m_CellIntermediateScale = defaultIntermediate; + desc.m_OutputIntermediateScale = defaultIntermediate; + } + if (tfLiteNode->intermediates->size > 4) + { + auto hiddentensorInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->intermediates->data[4]]); + desc.m_HiddenStateScale = hiddentensorInfo.GetQuantizationScale(); + desc.m_HiddenStateZeroPoint = hiddentensorInfo.GetQuantizationOffset(); + } const armnn::TensorInfo& inputTensorInfo = 
GetTensorInfoForTfLiteTensor(tfLiteInputTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); @@ -167,7 +197,11 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, { scratchBufferTensorInfo = armnn::TensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset); } - armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset); + armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, + cellStateInInfo.GetDataType(), + cellStateInInfo.GetQuantizationScale(), + cellStateInInfo.GetQuantizationOffset()); + armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset); armnn::LstmInputParamsInfo paramsInfo; @@ -218,9 +252,6 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo()); } - // hiddenStateOutput and cellStateOutput do not present in TfLite UnidirectionalSequenceLstm - armnn::Optional<armnn::TensorInfo> optionalTensor; - bool isSupported = false; auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported) { @@ -232,9 +263,9 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, inputTensorInfo, outputStateInInfo, cellStateInInfo, + outputStateOutTensorInfo, + cellStateOutTensorInfo, outputInfo, - optionalTensor, - optionalTensor, desc, paramsInfo); }; @@ -248,7 +279,9 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, armnn::IConnectableLayer* layer = delegateData.m_Network->AddUnidirectionalSequenceLstmLayer(desc, params); ARMNN_ASSERT(layer != nullptr); - layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + layer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo); + layer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo); + layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo); // Connect the 
inputs // input_layer @@ -258,7 +291,7 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, //outputStateIn delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[19]]->Connect(layer->GetInputSlot(2)); - armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0); + armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(2); delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->outputs->data[0])] = &outputSlot; return kTfLiteOk; } |