diff options
Diffstat (limited to 'delegate/opaque/src/UnidirectionalSequenceLstm.hpp')
-rw-r--r-- | delegate/opaque/src/UnidirectionalSequenceLstm.hpp | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/delegate/opaque/src/UnidirectionalSequenceLstm.hpp b/delegate/opaque/src/UnidirectionalSequenceLstm.hpp index 2fd64c0dd0..19a57e87df 100644 --- a/delegate/opaque/src/UnidirectionalSequenceLstm.hpp +++ b/delegate/opaque/src/UnidirectionalSequenceLstm.hpp @@ -226,7 +226,7 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData, const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true); - unsigned int batchSize = inputTensorInfo.GetShape()[0]; + unsigned int batchSize = desc.m_TimeMajor ? inputTensorInfo.GetShape()[1] : inputTensorInfo.GetShape()[0]; unsigned int outputSize = outputTensorInfo.GetShape()[2]; unsigned int numUnits = cellStateInInfo.GetShape()[1]; |