From 5f94124ac11afbbf2d2a4cda539b316964802c76 Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Fri, 11 Aug 2023 16:09:26 +0100 Subject: IVGCVSW-7964 Fix UnidirectionalSequenceLstm * Fix incorrect batch size and time size * Fix incorrect time major when max time =1 * Fix incorrect permutation * Fix incorrect scratch buffer * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: I510fae55528be412a58d020e82bd283852e7800b --- .../neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp') diff --git a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp index e48425e3ee..bbdcd1f855 100644 --- a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp +++ b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp @@ -603,7 +603,8 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, statusSplit = arm_compute::NESplit::validate(&aclPermuteOutInfo, splitterOutputsTensorInfosPtr, aclAxisSplit); - } else + } + else { statusSplit = arm_compute::NESplit::validate(&aclInputInfo, splitterOutputsTensorInfosPtr, aclAxisSplit); } @@ -740,7 +741,7 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, // Set input of LSTM to be first input ITensor. // Set output of LSTM to be final output ITensor. // LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo. - if (maxTime == 1 && !descriptor.m_TimeMajor) + if (maxTime == 1 && descriptor.m_TimeMajor) { TensorShape inputShape = GetTensorShape(aclInputInfo.tensor_shape(), 1U); TensorShape outputShape = GetTensorShape(aclOutputInfo.tensor_shape(), 1U); -- cgit v1.2.1