aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2023-05-29 15:54:57 +0100
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2023-06-07 13:35:30 +0000
commit71311e4c47cf54a80f609d00c34e0da8b6e7d86c (patch)
tree366ad6a6f08951c1b3eccd6e3240581d2f0969b1
parentd3f5a0777a68a4f4f1e5c5925173764eeabe1c45 (diff)
downloadarmnn-71311e4c47cf54a80f609d00c34e0da8b6e7d86c.tar.gz
Fix incorrect validation of Unidirectional Sequence LSTM on Cl and Neon
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: I54c60fb98b9c560c300572f46d42b13aec7e402e
-rw-r--r--src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp29
-rw-r--r--src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp6
-rw-r--r--src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp21
3 files changed, 41 insertions, 15 deletions
diff --git a/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp
index 289442e1cc..fb31d7c283 100644
--- a/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp
+++ b/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -508,17 +508,21 @@ arm_compute::Status
ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
const TensorInfo& outputStateIn,
const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
const TensorInfo& output,
- const Optional<TensorInfo>& hiddenStateOutput,
- const Optional<TensorInfo>& cellStateOutput,
const UnidirectionalSequenceLstmDescriptor& descriptor,
const LstmInputParamsInfo& paramsInfo)
{
- IgnoreUnused(hiddenStateOutput, cellStateOutput);
-
TensorShape inputLayerShape = input.GetShape();
TensorShape outputLayerShape = output.GetShape();
+ if (inputLayerShape.GetNumDimensions() != 3)
+ {
+ return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
+ "Unidirectional Sequence LSTM layer validate status failed.");
+ }
+
unsigned int maxTime = descriptor.m_TimeMajor?inputLayerShape[0]:inputLayerShape[1];
unsigned int batchSize = descriptor.m_TimeMajor?inputLayerShape[1]:inputLayerShape[0];
unsigned int inputSize = inputLayerShape[2];
@@ -544,7 +548,7 @@ ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
//
// Permute validate
//
- TensorInfo permuteOutInfo = TensorInfo(input);
+ TensorInfo permuteOutInfo = armnnUtils::Permuted(input, { 1U, 0U, 2U });
arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo);
if (!descriptor.m_TimeMajor)
{
@@ -610,9 +614,16 @@ ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
- const TensorInfo& scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType());
- const TensorInfo& outputStateOut = TensorInfo(outputStateIn.GetShape(), input.GetDataType());
- const TensorInfo& cellStateOut = TensorInfo(cellStateIn.GetShape(), input.GetDataType());
+ unsigned int numUnits = cellStateIn.GetShape()[1];
+ unsigned int scratchBufferFactor = 4;
+
+ if (descriptor.m_CifgEnabled)
+ {
+ // scratchBuffer = { batchSize, numUnits * 3 } with CIFG
+ scratchBufferFactor = 3;
+ }
+
+ const TensorInfo& scratchBuffer = TensorInfo({ batchSize, numUnits * scratchBufferFactor }, input.GetDataType());
// The inputs and outputs
const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);
diff --git a/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp
index f50e0a90b2..28844897d6 100644
--- a/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp
+++ b/src/backends/cl/workloads/ClUnidirectionalSequenceLstmFloatWorkload.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -87,9 +87,9 @@ arm_compute::Status
ClUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
const TensorInfo& outputStateIn,
const TensorInfo& cellStateIn,
+ const TensorInfo& outputStateOut,
+ const TensorInfo& cellStateOut,
const TensorInfo& output,
- const Optional<TensorInfo>& hiddenStateOutput,
- const Optional<TensorInfo>& cellStateOutput,
const UnidirectionalSequenceLstmDescriptor& descriptor,
const LstmInputParamsInfo& paramsInfo);
diff --git a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
index 7bdb2d5a5a..1905bcb659 100644
--- a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -519,6 +519,12 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
TensorShape inputLayerShape = input.GetShape();
TensorShape outputLayerShape = output.GetShape();
+ if (inputLayerShape.GetNumDimensions() != 3)
+ {
+ return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
+ "Unidirectional Sequence LSTM layer validate status failed.");
+ }
+
unsigned int maxTime = descriptor.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1];
unsigned int batchSize = descriptor.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0];
unsigned int inputSize = inputLayerShape[2];
@@ -544,7 +550,7 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
//
// Permute validate
//
- TensorInfo permuteOutInfo = TensorInfo(input);
+ TensorInfo permuteOutInfo = armnnUtils::Permuted(input, { 1U, 0U, 2U });
arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo);
if (!descriptor.m_TimeMajor)
{
@@ -609,7 +615,16 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
- const TensorInfo& scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType());
+ unsigned int numUnits = cellStateIn.GetShape()[1];
+ unsigned int scratchBufferFactor = 4;
+
+ if (descriptor.m_CifgEnabled)
+ {
+ // scratchBuffer = { batchSize, numUnits * 3 } with CIFG
+ scratchBufferFactor = 3;
+ }
+
+ const TensorInfo& scratchBuffer = TensorInfo({ batchSize, numUnits * scratchBufferFactor }, input.GetDataType());
// The inputs and outputs
const arm_compute::TensorInfo aclOutputStateInInfo = BuildArmComputeTensorInfo(outputStateIn);