aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2022-04-21 11:57:09 +0100
committermike.kelly <mike.kelly@arm.com>2022-05-05 08:29:20 +0000
commit1299496996bc332f02218f926640a9255ed60310 (patch)
tree2d242e142bd8fe7387140bcf8cdf39cd13ffc9eb /src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
parent8272a7bda2974c39b6c06e3eb3a000f2bdb749f7 (diff)
downloadarmnn-1299496996bc332f02218f926640a9255ed60310.tar.gz
IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon
* Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported * outputStateOut TensorInfo is not optional. * cellStateOut TensorInfo is not optional. * TensorInfo Order matches other QLSTM/LSTM layers. * Added missing parameters to UnidirectionalSequenceLstmOperator for delegate. * Added quantized UnidirectionalSequenceLstm support to Neon !android-nn-driver:7457 Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4
Diffstat (limited to 'src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp8
1 files changed, 5 insertions, 3 deletions
diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
index d447a46b23..c4345d4978 100644
--- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
@@ -59,7 +59,9 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*>
TensorInfo inputInfo = GetTensorInfo(inputs[0]);
const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
- TensorInfo outputInfo = GetTensorInfo(outputs[0]);
+ TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]);
+ TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]);
+ TensorInfo outputInfo = GetTensorInfo(outputs[2]);
TensorShape& inputShape = inputInfo.GetShape();
TensorShape& outputShape= outputInfo.GetShape();
auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());
@@ -140,7 +142,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*>
auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
- auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map());
+ auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map());
std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
@@ -296,7 +298,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*>
{
// Permute Output back to batch major
const PermutationVector& mappings = {1U, 0U, 2U};
- auto outputData = reinterpret_cast<float*>(outputs[0]->Map());
+ auto outputData = reinterpret_cast<float*>(outputs[2]->Map());
std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
outputInfo.SetShape(outputShape);