diff options
Diffstat (limited to 'delegate/src/UnidirectionalSequenceLstm.hpp')
-rw-r--r-- | delegate/src/UnidirectionalSequenceLstm.hpp | 49 |
1 file changed, 41 insertions, 8 deletions
diff --git a/delegate/src/UnidirectionalSequenceLstm.hpp b/delegate/src/UnidirectionalSequenceLstm.hpp index a923874a74..bcf01cf2a9 100644 --- a/delegate/src/UnidirectionalSequenceLstm.hpp +++ b/delegate/src/UnidirectionalSequenceLstm.hpp @@ -151,6 +151,36 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, || params.m_OutputLayerNormWeights != nullptr); desc.m_TimeMajor = nodeParams->time_major; + if (tfLiteNode->intermediates->size > 3 && desc.m_LayerNormEnabled) + { + auto inputIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[0]]); + auto forgetIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[1]]); + auto cellIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[2]]); + auto outputIntermediateTensorInfo = GetTensorInfoForTfLiteTensor( + tfLiteTensors[tfLiteNode->intermediates->data[3]]); + + desc.m_InputIntermediateScale = inputIntermediateTensorInfo.GetQuantizationScale(); + desc.m_ForgetIntermediateScale = forgetIntermediateTensorInfo.GetQuantizationScale(); + desc.m_CellIntermediateScale = cellIntermediateTensorInfo.GetQuantizationScale(); + desc.m_OutputIntermediateScale = outputIntermediateTensorInfo.GetQuantizationScale(); + } + else + { + float defaultIntermediate = std::pow(2, -12); + desc.m_InputIntermediateScale = defaultIntermediate; + desc.m_ForgetIntermediateScale = defaultIntermediate; + desc.m_CellIntermediateScale = defaultIntermediate; + desc.m_OutputIntermediateScale = defaultIntermediate; + } + if (tfLiteNode->intermediates->size > 4) + { + auto hiddentensorInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->intermediates->data[4]]); + desc.m_HiddenStateScale = hiddentensorInfo.GetQuantizationScale(); + desc.m_HiddenStateZeroPoint = hiddentensorInfo.GetQuantizationOffset(); + } const armnn::TensorInfo& inputTensorInfo = 
GetTensorInfoForTfLiteTensor(tfLiteInputTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); @@ -167,7 +197,11 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, { scratchBufferTensorInfo = armnn::TensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset); } - armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset); + armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, + cellStateInInfo.GetDataType(), + cellStateInInfo.GetQuantizationScale(), + cellStateInInfo.GetQuantizationOffset()); + armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset); armnn::LstmInputParamsInfo paramsInfo; @@ -218,9 +252,6 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo()); } - // hiddenStateOutput and cellStateOutput do not present in TfLite UnidirectionalSequenceLstm - armnn::Optional<armnn::TensorInfo> optionalTensor; - bool isSupported = false; auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported) { @@ -232,9 +263,9 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, inputTensorInfo, outputStateInInfo, cellStateInInfo, + outputStateOutTensorInfo, + cellStateOutTensorInfo, outputInfo, - optionalTensor, - optionalTensor, desc, paramsInfo); }; @@ -248,7 +279,9 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, armnn::IConnectableLayer* layer = delegateData.m_Network->AddUnidirectionalSequenceLstmLayer(desc, params); ARMNN_ASSERT(layer != nullptr); - layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + layer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo); + layer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo); + layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo); // Connect the 
inputs // input_layer @@ -258,7 +291,7 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, //outputStateIn delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[19]]->Connect(layer->GetInputSlot(2)); - armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0); + armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(2); delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->outputs->data[0])] = &outputSlot; return kTfLiteOk; } |