From 71311e4c47cf54a80f609d00c34e0da8b6e7d86c Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Mon, 29 May 2023 15:54:57 +0100 Subject: Fix incorrect validation of Unidirectional Sequence LSTM on Cl and Neon Signed-off-by: Narumol Prangnawarat Change-Id: I54c60fb98b9c560c300572f46d42b13aec7e402e --- .../ClUnidirectionalSequenceLstmFloatWorkload.cpp | 29 +++++++++++++++------- .../ClUnidirectionalSequenceLstmFloatWorkload.hpp | 6 ++--- ...NeonUnidirectionalSequenceLstmFloatWorkload.cpp | 21 +++++++++++++--- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp index 289442e1cc..fb31d7c283 100644 --- a/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp +++ b/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -508,17 +508,21 @@ arm_compute::Status ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, const TensorInfo& output, - const Optional& hiddenStateOutput, - const Optional& cellStateOutput, const UnidirectionalSequenceLstmDescriptor& descriptor, const LstmInputParamsInfo& paramsInfo) { - IgnoreUnused(hiddenStateOutput, cellStateOutput); - TensorShape inputLayerShape = input.GetShape(); TensorShape outputLayerShape = output.GetShape(); + if (inputLayerShape.GetNumDimensions() != 3) + { + return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR, + "Unidirectional Sequence LSTM layer validate status failed."); + } + unsigned int maxTime = descriptor.m_TimeMajor?inputLayerShape[0]:inputLayerShape[1]; unsigned int batchSize = descriptor.m_TimeMajor?inputLayerShape[1]:inputLayerShape[0]; unsigned int inputSize = inputLayerShape[2]; @@ -544,7 +548,7 @@ ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, // // Permute validate // - TensorInfo permuteOutInfo = TensorInfo(input); + TensorInfo permuteOutInfo = armnnUtils::Permuted(input, { 1U, 0U, 2U }); arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo); if (!descriptor.m_TimeMajor) { @@ -610,9 +614,16 @@ ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, arm_compute::LSTMParams lstm_params_info; - const TensorInfo& scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType()); - const TensorInfo& outputStateOut = TensorInfo(outputStateIn.GetShape(), input.GetDataType()); - const TensorInfo& cellStateOut = TensorInfo(cellStateIn.GetShape(), input.GetDataType()); + unsigned int numUnits = cellStateIn.GetShape()[1]; + unsigned int scratchBufferFactor = 4; + + if (descriptor.m_CifgEnabled) + { + // scratchBuffer = { batchSize, numUnits * 3 } with CIFG + scratchBufferFactor = 3; + } + + const TensorInfo& scratchBuffer = TensorInfo({ batchSize, numUnits * scratchBufferFactor }, input.GetDataType()); // The inputs and outputs const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn); diff --git a/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp index f50e0a90b2..28844897d6 100644 --- a/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp +++ b/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -87,9 +87,9 @@ arm_compute::Status ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, const TensorInfo& output, - const Optional& hiddenStateOutput, - const Optional& cellStateOutput, const UnidirectionalSequenceLstmDescriptor& descriptor, const LstmInputParamsInfo& paramsInfo); diff --git a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp index 7bdb2d5a5a..1905bcb659 100644 --- a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp +++ b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -519,6 +519,12 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, TensorShape inputLayerShape = input.GetShape(); TensorShape outputLayerShape = output.GetShape(); + if (inputLayerShape.GetNumDimensions() != 3) + { + return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR, + "Unidirectional Sequence LSTM layer validate status failed."); + } + unsigned int maxTime = descriptor.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1]; unsigned int batchSize = descriptor.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0]; unsigned int inputSize = inputLayerShape[2]; @@ -544,7 +550,7 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, // // Permute validate // - TensorInfo permuteOutInfo = TensorInfo(input); + TensorInfo permuteOutInfo = armnnUtils::Permuted(input, { 1U, 0U, 2U }); arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo); if (!descriptor.m_TimeMajor) { @@ -609,7 +615,16 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, arm_compute::LSTMParams lstm_params_info; - const TensorInfo& scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType()); + unsigned int numUnits = cellStateIn.GetShape()[1]; + unsigned int scratchBufferFactor = 4; + + if (descriptor.m_CifgEnabled) + { + // scratchBuffer = { batchSize, numUnits * 3 } with CIFG + scratchBufferFactor = 3; + } + + const TensorInfo& scratchBuffer = TensorInfo({ batchSize, numUnits * scratchBufferFactor }, input.GetDataType()); // The inputs and outputs const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn); -- cgit v1.2.1