diff options
Diffstat (limited to 'src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp index d447a46b23..c4345d4978 100644 --- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -59,7 +59,9 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> TensorInfo inputInfo = GetTensorInfo(inputs[0]); const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]); const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]); - TensorInfo outputInfo = GetTensorInfo(outputs[0]); + TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]); + TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]); + TensorInfo outputInfo = GetTensorInfo(outputs[2]); TensorShape& inputShape = inputInfo.GetShape(); TensorShape& outputShape= outputInfo.GetShape(); auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map()); @@ -140,7 +142,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map()); std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData); - auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map()); + auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map()); std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData); std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData); @@ -296,7 +298,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> { // Permute Output back to batch major const PermutationVector& mappings = {1U, 0U, 2U}; - auto outputData = reinterpret_cast<float*>(outputs[0]->Map()); + auto outputData = reinterpret_cast<float*>(outputs[2]->Map()); std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements()); outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); outputInfo.SetShape(outputShape); |