aboutsummaryrefslogtreecommitdiff
path: root/delegate/opaque/src/UnidirectionalSequenceLstm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/opaque/src/UnidirectionalSequenceLstm.hpp')
-rw-r--r--delegate/opaque/src/UnidirectionalSequenceLstm.hpp2
1 files changed, 1 insertions, 1 deletions
diff --git a/delegate/opaque/src/UnidirectionalSequenceLstm.hpp b/delegate/opaque/src/UnidirectionalSequenceLstm.hpp
index 2fd64c0dd0..19a57e87df 100644
--- a/delegate/opaque/src/UnidirectionalSequenceLstm.hpp
+++ b/delegate/opaque/src/UnidirectionalSequenceLstm.hpp
@@ -226,7 +226,7 @@ TfLiteStatus VisitUnidirectionalSequenceLstmOperator(DelegateData& delegateData,
const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteInputTensor);
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteOpaqueTensor(tfLiteOutputTensor, true);
- unsigned int batchSize = inputTensorInfo.GetShape()[0];
+ unsigned int batchSize = desc.m_TimeMajor ? inputTensorInfo.GetShape()[1] : inputTensorInfo.GetShape()[0];
unsigned int outputSize = outputTensorInfo.GetShape()[2];
unsigned int numUnits = cellStateInInfo.GetShape()[1];