diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-04-21 11:57:09 +0100 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2022-05-05 08:29:20 +0000 |
commit | 1299496996bc332f02218f926640a9255ed60310 (patch) | |
tree | 2d242e142bd8fe7387140bcf8cdf39cd13ffc9eb /src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp | |
parent | 8272a7bda2974c39b6c06e3eb3a000f2bdb749f7 (diff) | |
download | armnn-1299496996bc332f02218f926640a9255ed60310.tar.gz |
IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon
* Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported
* outputStateOut TensorInfo is not optional.
* cellStateOut TensorInfo is not optional.
* TensorInfo Order matches other QLSTM/LSTM layers.
* Added missing parameters to UnidirectionalSequenceLstmOperator for
delegate.
* Added quantized UnidirectionalSequenceLstm support to Neon
!android-nn-driver:7457
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4
Diffstat (limited to 'src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp index d447a46b23..c4345d4978 100644 --- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -59,7 +59,9 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> TensorInfo inputInfo = GetTensorInfo(inputs[0]); const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]); const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]); - TensorInfo outputInfo = GetTensorInfo(outputs[0]); + TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]); + TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]); + TensorInfo outputInfo = GetTensorInfo(outputs[2]); TensorShape& inputShape = inputInfo.GetShape(); TensorShape& outputShape= outputInfo.GetShape(); auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map()); @@ -140,7 +142,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map()); std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData); - auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map()); + auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map()); std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData); std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData); @@ -296,7 +298,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> { // Permute Output back to batch major const PermutationVector& mappings = {1U, 0U, 2U}; - auto outputData = reinterpret_cast<float*>(outputs[0]->Map()); + auto outputData = reinterpret_cast<float*>(outputs[2]->Map()); std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements()); outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); outputInfo.SetShape(outputShape); |