From 1299496996bc332f02218f926640a9255ed60310 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Thu, 21 Apr 2022 11:57:09 +0100 Subject: IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon * Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported * outputStateOut TensorInfo is not optional. * cellStateOut TensorInfo is not optional. * TensorInfo Order matches other QLSTM/LSTM layers. * Added missing parameters to UnidirectionalSequenceLstmOperator for delegate. * Added quantized UnidirectionalSequenceLstm support to Neon !android-nn-driver:7457 Signed-off-by: Mike Kelly Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4 --- .../reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp index d447a46b23..c4345d4978 100644 --- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -59,7 +59,9 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector TensorInfo inputInfo = GetTensorInfo(inputs[0]); const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]); const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]); - TensorInfo outputInfo = GetTensorInfo(outputs[0]); + TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]); + TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]); + TensorInfo outputInfo = GetTensorInfo(outputs[2]); TensorShape& inputShape = inputInfo.GetShape(); TensorShape& outputShape= outputInfo.GetShape(); auto inputTensor = reinterpret_cast(inputs[0]->Map()); @@ -140,7 +142,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector auto currentInputData = reinterpret_cast(inputs[0]->Map()); std::unique_ptr> inputData = MakeDecoder(lstmInputInfo, currentInputData); - auto currentOutputData = reinterpret_cast(outputs[0]->Map()); + auto currentOutputData = reinterpret_cast(outputs[2]->Map()); std::unique_ptr> output = MakeEncoder(lstmOutputInfo, currentOutputData); std::unique_ptr> outputDecoder = MakeDecoder(lstmOutputInfo, currentOutputData); @@ -296,7 +298,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector { // Permute Output back to batch major const PermutationVector& mappings = {1U, 0U, 2U}; - auto outputData = reinterpret_cast(outputs[0]->Map()); + auto outputData = reinterpret_cast(outputs[2]->Map()); std::vector outputValue(outputData, outputData + outputInfo.GetNumElements()); outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); outputInfo.SetShape(outputShape); -- cgit v1.2.1