aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2023-08-11 16:09:26 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2023-08-21 13:34:54 +0100
commit5f94124ac11afbbf2d2a4cda539b316964802c76 (patch)
tree38d6f31517fe322cb0065a8c3e947801791a6405 /src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
parentc4f42340bd3d6664098c69d2fb044089aa39aea0 (diff)
downloadarmnn-5f94124ac11afbbf2d2a4cda539b316964802c76.tar.gz
IVGCVSW-7964 Fix UnidirectionalSequenceLstm
* Fix incorrect batch size and time size * Fix incorrect time major when max time =1 * Fix incorrect permutation * Fix incorrect scratch buffer * Unit tests Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I510fae55528be412a58d020e82bd283852e7800b
Diffstat (limited to 'src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp')
-rw-r--r--src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp5
1 files changed, 3 insertions, 2 deletions
diff --git a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
index e48425e3ee..bbdcd1f855 100644
--- a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
@@ -603,7 +603,8 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
statusSplit = arm_compute::NESplit::validate(&aclPermuteOutInfo,
splitterOutputsTensorInfosPtr,
aclAxisSplit);
- } else
+ }
+ else
{
statusSplit = arm_compute::NESplit::validate(&aclInputInfo, splitterOutputsTensorInfosPtr, aclAxisSplit);
}
@@ -740,7 +741,7 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
// Set input of LSTM to be first input ITensor.
// Set output of LSTM to be final output ITensor.
// LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo.
- if (maxTime == 1 && !descriptor.m_TimeMajor)
+ if (maxTime == 1 && descriptor.m_TimeMajor)
{
TensorShape inputShape = GetTensorShape(aclInputInfo.tensor_shape(), 1U);
TensorShape outputShape = GetTensorShape(aclOutputInfo.tensor_shape(), 1U);