aboutsummaryrefslogtreecommitdiff
path: root/delegate
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2022-04-21 11:57:09 +0100
committermike.kelly <mike.kelly@arm.com>2022-05-05 08:29:20 +0000
commit1299496996bc332f02218f926640a9255ed60310 (patch)
tree2d242e142bd8fe7387140bcf8cdf39cd13ffc9eb /delegate
parent8272a7bda2974c39b6c06e3eb3a000f2bdb749f7 (diff)
downloadarmnn-1299496996bc332f02218f926640a9255ed60310.tar.gz
IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon
* Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported * outputStateOut TensorInfo is not optional. * cellStateOut TensorInfo is not optional. * TensorInfo Order matches other QLSTM/LSTM layers. * Added missing parameters to UnidirectionalSequenceLstmOperator for delegate. * Added quantized UnidirectionalSequenceLstm support to Neon !android-nn-driver:7457 Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4
Diffstat (limited to 'delegate')
-rw-r--r--delegate/src/UnidirectionalSequenceLstm.hpp49
1 files changed, 41 insertions, 8 deletions
diff --git a/delegate/src/UnidirectionalSequenceLstm.hpp b/delegate/src/UnidirectionalSequenceLstm.hpp
index a923874a74..bcf01cf2a9 100644
--- a/delegate/src/UnidirectionalSequenceLstm.hpp
+++ b/delegate/src/UnidirectionalSequenceLstm.hpp
@@ -151,6 +151,36 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
|| params.m_OutputLayerNormWeights != nullptr);
desc.m_TimeMajor = nodeParams->time_major;
+ if (tfLiteNode->intermediates->size > 3 && desc.m_LayerNormEnabled)
+ {
+ auto inputIntermediateTensorInfo = GetTensorInfoForTfLiteTensor(
+ tfLiteTensors[tfLiteNode->intermediates->data[0]]);
+ auto forgetIntermediateTensorInfo = GetTensorInfoForTfLiteTensor(
+ tfLiteTensors[tfLiteNode->intermediates->data[1]]);
+ auto cellIntermediateTensorInfo = GetTensorInfoForTfLiteTensor(
+ tfLiteTensors[tfLiteNode->intermediates->data[2]]);
+ auto outputIntermediateTensorInfo = GetTensorInfoForTfLiteTensor(
+ tfLiteTensors[tfLiteNode->intermediates->data[3]]);
+
+ desc.m_InputIntermediateScale = inputIntermediateTensorInfo.GetQuantizationScale();
+ desc.m_ForgetIntermediateScale = forgetIntermediateTensorInfo.GetQuantizationScale();
+ desc.m_CellIntermediateScale = cellIntermediateTensorInfo.GetQuantizationScale();
+ desc.m_OutputIntermediateScale = outputIntermediateTensorInfo.GetQuantizationScale();
+ }
+ else
+ {
+ float defaultIntermediate = std::pow(2, -12);
+ desc.m_InputIntermediateScale = defaultIntermediate;
+ desc.m_ForgetIntermediateScale = defaultIntermediate;
+ desc.m_CellIntermediateScale = defaultIntermediate;
+ desc.m_OutputIntermediateScale = defaultIntermediate;
+ }
+ if (tfLiteNode->intermediates->size > 4)
+ {
+ auto hiddentensorInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->intermediates->data[4]]);
+ desc.m_HiddenStateScale = hiddentensorInfo.GetQuantizationScale();
+ desc.m_HiddenStateZeroPoint = hiddentensorInfo.GetQuantizationOffset();
+ }
const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
@@ -167,7 +197,11 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
{
scratchBufferTensorInfo = armnn::TensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset);
}
- armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset);
+ armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits},
+ cellStateInInfo.GetDataType(),
+ cellStateInInfo.GetQuantizationScale(),
+ cellStateInInfo.GetQuantizationOffset());
+
armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);
armnn::LstmInputParamsInfo paramsInfo;
@@ -218,9 +252,6 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
}
- // hiddenStateOutput and cellStateOutput do not present in TfLite UnidirectionalSequenceLstm
- armnn::Optional<armnn::TensorInfo> optionalTensor;
-
bool isSupported = false;
auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
{
@@ -232,9 +263,9 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
inputTensorInfo,
outputStateInInfo,
cellStateInInfo,
+ outputStateOutTensorInfo,
+ cellStateOutTensorInfo,
outputInfo,
- optionalTensor,
- optionalTensor,
desc,
paramsInfo);
};
@@ -248,7 +279,9 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
armnn::IConnectableLayer* layer = delegateData.m_Network->AddUnidirectionalSequenceLstmLayer(desc, params);
ARMNN_ASSERT(layer != nullptr);
- layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ layer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
+ layer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
+ layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
// Connect the inputs
// input_layer
@@ -258,7 +291,7 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
//outputStateIn
delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[19]]->Connect(layer->GetInputSlot(2));
- armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
+ armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(2);
delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->outputs->data[0])] = &outputSlot;
return kTfLiteOk;
}