aboutsummaryrefslogtreecommitdiff
path: root/delegate/src/UnidirectionalSequenceLstm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/src/UnidirectionalSequenceLstm.hpp')
-rw-r--r--delegate/src/UnidirectionalSequenceLstm.hpp49
1 files changed, 41 insertions, 8 deletions
diff --git a/delegate/src/UnidirectionalSequenceLstm.hpp b/delegate/src/UnidirectionalSequenceLstm.hpp
index a923874a74..bcf01cf2a9 100644
--- a/delegate/src/UnidirectionalSequenceLstm.hpp
+++ b/delegate/src/UnidirectionalSequenceLstm.hpp
@@ -151,6 +151,36 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
|| params.m_OutputLayerNormWeights != nullptr);
desc.m_TimeMajor = nodeParams->time_major;
+ if (tfLiteNode->intermediates->size > 3 && desc.m_LayerNormEnabled)
+ {
+ auto inputIntermediateTensorInfo = GetTensorInfoForTfLiteTensor(
+ tfLiteTensors[tfLiteNode->intermediates->data[0]]);
+ auto forgetIntermediateTensorInfo = GetTensorInfoForTfLiteTensor(
+ tfLiteTensors[tfLiteNode->intermediates->data[1]]);
+ auto cellIntermediateTensorInfo = GetTensorInfoForTfLiteTensor(
+ tfLiteTensors[tfLiteNode->intermediates->data[2]]);
+ auto outputIntermediateTensorInfo = GetTensorInfoForTfLiteTensor(
+ tfLiteTensors[tfLiteNode->intermediates->data[3]]);
+
+ desc.m_InputIntermediateScale = inputIntermediateTensorInfo.GetQuantizationScale();
+ desc.m_ForgetIntermediateScale = forgetIntermediateTensorInfo.GetQuantizationScale();
+ desc.m_CellIntermediateScale = cellIntermediateTensorInfo.GetQuantizationScale();
+ desc.m_OutputIntermediateScale = outputIntermediateTensorInfo.GetQuantizationScale();
+ }
+ else
+ {
+ float defaultIntermediate = std::pow(2, -12);
+ desc.m_InputIntermediateScale = defaultIntermediate;
+ desc.m_ForgetIntermediateScale = defaultIntermediate;
+ desc.m_CellIntermediateScale = defaultIntermediate;
+ desc.m_OutputIntermediateScale = defaultIntermediate;
+ }
+ if (tfLiteNode->intermediates->size > 4)
+ {
+ auto hiddentensorInfo = GetTensorInfoForTfLiteTensor(tfLiteTensors[tfLiteNode->intermediates->data[4]]);
+ desc.m_HiddenStateScale = hiddentensorInfo.GetQuantizationScale();
+ desc.m_HiddenStateZeroPoint = hiddentensorInfo.GetQuantizationOffset();
+ }
const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
@@ -167,7 +197,11 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
{
scratchBufferTensorInfo = armnn::TensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset);
}
- armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset);
+ armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits},
+ cellStateInInfo.GetDataType(),
+ cellStateInInfo.GetQuantizationScale(),
+ cellStateInInfo.GetQuantizationOffset());
+
armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);
armnn::LstmInputParamsInfo paramsInfo;
@@ -218,9 +252,6 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
}
- // hiddenStateOutput and cellStateOutput do not present in TfLite UnidirectionalSequenceLstm
- armnn::Optional<armnn::TensorInfo> optionalTensor;
-
bool isSupported = false;
auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
{
@@ -232,9 +263,9 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
inputTensorInfo,
outputStateInInfo,
cellStateInInfo,
+ outputStateOutTensorInfo,
+ cellStateOutTensorInfo,
outputInfo,
- optionalTensor,
- optionalTensor,
desc,
paramsInfo);
};
@@ -248,7 +279,9 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
armnn::IConnectableLayer* layer = delegateData.m_Network->AddUnidirectionalSequenceLstmLayer(desc, params);
ARMNN_ASSERT(layer != nullptr);
- layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ layer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
+ layer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
+ layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
// Connect the inputs
// input_layer
@@ -258,7 +291,7 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
//outputStateIn
delegateData.m_OutputSlotForNode[tfLiteNode->inputs->data[19]]->Connect(layer->GetInputSlot(2));
- armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
+ armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(2);
delegateData.m_OutputSlotForNode[static_cast<unsigned long>(tfLiteNode->outputs->data[0])] = &outputSlot;
return kTfLiteOk;
}