diff options
Diffstat (limited to 'delegate/classic/src/UnidirectionalSequenceLstm.hpp')
-rw-r--r-- | delegate/classic/src/UnidirectionalSequenceLstm.hpp | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/delegate/classic/src/UnidirectionalSequenceLstm.hpp b/delegate/classic/src/UnidirectionalSequenceLstm.hpp index 5fa6bb0260..3529640aa1 100644 --- a/delegate/classic/src/UnidirectionalSequenceLstm.hpp +++ b/delegate/classic/src/UnidirectionalSequenceLstm.hpp @@ -184,7 +184,7 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true); - unsigned int batchSize = inputTensorInfo.GetShape()[0]; + unsigned int batchSize = desc.m_TimeMajor ? inputTensorInfo.GetShape()[1] : inputTensorInfo.GetShape()[0]; unsigned int outputSize = outputTensorInfo.GetShape()[2]; unsigned int numUnits = cellStateInInfo.GetShape()[1]; |