aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp8
1 files changed, 5 insertions, 3 deletions
diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
index d447a46b23..c4345d4978 100644
--- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
@@ -59,7 +59,9 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*>
TensorInfo inputInfo = GetTensorInfo(inputs[0]);
const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]);
const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]);
- TensorInfo outputInfo = GetTensorInfo(outputs[0]);
+ TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]);
+ TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]);
+ TensorInfo outputInfo = GetTensorInfo(outputs[2]);
TensorShape& inputShape = inputInfo.GetShape();
TensorShape& outputShape= outputInfo.GetShape();
auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map());
@@ -140,7 +142,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*>
auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map());
std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData);
- auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map());
+ auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map());
std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData);
std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData);
@@ -296,7 +298,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*>
{
// Permute Output back to batch major
const PermutationVector& mappings = {1U, 0U, 2U};
- auto outputData = reinterpret_cast<float*>(outputs[0]->Map());
+ auto outputData = reinterpret_cast<float*>(outputs[2]->Map());
std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements());
outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings);
outputInfo.SetShape(outputShape);