From 71311e4c47cf54a80f609d00c34e0da8b6e7d86c Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Mon, 29 May 2023 15:54:57 +0100 Subject: Fix incorrect validation of Unidirectional Sequence LSTM on Cl and Neon Signed-off-by: Narumol Prangnawarat Change-Id: I54c60fb98b9c560c300572f46d42b13aec7e402e --- .../NeonUnidirectionalSequenceLstmFloatWorkload.cpp | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) (limited to 'src/backends/neon') diff --git a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp index 7bdb2d5a5a..1905bcb659 100644 --- a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp +++ b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -519,6 +519,12 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, TensorShape inputLayerShape = input.GetShape(); TensorShape outputLayerShape = output.GetShape(); + if (inputLayerShape.GetNumDimensions() != 3) + { + return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR, + "Unidirectional Sequence LSTM layer validate status failed."); + } + unsigned int maxTime = descriptor.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1]; unsigned int batchSize = descriptor.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0]; unsigned int inputSize = inputLayerShape[2]; @@ -544,7 +550,7 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, // // Permute validate // - TensorInfo permuteOutInfo = TensorInfo(input); + TensorInfo permuteOutInfo = armnnUtils::Permuted(input, { 1U, 0U, 2U }); arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo); if (!descriptor.m_TimeMajor) { @@ -609,7 +615,16 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, arm_compute::LSTMParams lstm_params_info; - const TensorInfo& scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType()); + unsigned int numUnits = cellStateIn.GetShape()[1]; + unsigned int scratchBufferFactor = 4; + + if (descriptor.m_CifgEnabled) + { + // scratchBuffer = { batchSize, numUnits * 3 } with CIFG + scratchBufferFactor = 3; + } + + const TensorInfo& scratchBuffer = TensorInfo({ batchSize, numUnits * scratchBufferFactor }, input.GetDataType()); // The inputs and outputs const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn); -- cgit v1.2.1