about | summary | refs | log | tree | commit | diff
path: root/src/backends/neon
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon')
-rw-r--r-- src/backends/neon/test/NeonLayerTests.cpp | 2
-rw-r--r-- src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp | 5
-rw-r--r-- src/backends/neon/workloads/NeonUnidirectionalSequenceLstmWorkload.cpp | 22
3 files changed, 24 insertions, 5 deletions
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp
index ae8352d09e..588c90be6d 100644
--- a/src/backends/neon/test/NeonLayerTests.cpp
+++ b/src/backends/neon/test/NeonLayerTests.cpp
@@ -1092,6 +1092,8 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32TimeMajorSin
UnidirectionalSequenceLstmLayerFloat32TimeMajorSingleBatchTest)
ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32BatchMajorSingleBatch,
UnidirectionalSequenceLstmLayerFloat32BatchMajorSingleBatchTest)
+ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32TimeMajorSingleTime,
+ UnidirectionalSequenceLstmLayerFloat32TimeMajorSingleTimeTest)
ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32,
UnidirectionalSequenceLstmLayerFloat32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32TimeMajor,
diff --git a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
index e48425e3ee..bbdcd1f855 100644
--- a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
+++ b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp
@@ -603,7 +603,8 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
statusSplit = arm_compute::NESplit::validate(&aclPermuteOutInfo,
splitterOutputsTensorInfosPtr,
aclAxisSplit);
- } else
+ }
+ else
{
statusSplit = arm_compute::NESplit::validate(&aclInputInfo, splitterOutputsTensorInfosPtr, aclAxisSplit);
}
@@ -740,7 +741,7 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input,
// Set input of LSTM to be first input ITensor.
// Set output of LSTM to be final output ITensor.
// LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo.
- if (maxTime == 1 && !descriptor.m_TimeMajor)
+ if (maxTime == 1 && descriptor.m_TimeMajor)
{
TensorShape inputShape = GetTensorShape(aclInputInfo.tensor_shape(), 1U);
TensorShape outputShape = GetTensorShape(aclOutputInfo.tensor_shape(), 1U);
diff --git a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmWorkload.cpp b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmWorkload.cpp
index 8a1747edd1..984a5dc549 100644
--- a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmWorkload.cpp
+++ b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmWorkload.cpp
@@ -500,6 +500,12 @@ NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input,
TensorShape inputLayerShape = input.GetShape();
TensorShape outputLayerShape = output.GetShape();
+ if (inputLayerShape.GetNumDimensions() != 3)
+ {
+ return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR,
+ "Unidirectional Sequence LSTM layer validate status failed.");
+ }
+
unsigned int maxTime = descriptor.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1];
unsigned int batchSize = descriptor.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0];
unsigned int inputSize = inputLayerShape[2];
@@ -525,7 +531,7 @@ NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input,
//
// Permute validate
//
- TensorInfo permuteOutInfo = TensorInfo(input);
+ TensorInfo permuteOutInfo = armnnUtils::Permuted(input, { 1U, 0U, 2U });
arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo);
if (!descriptor.m_TimeMajor)
{
@@ -590,7 +596,17 @@ NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input,
arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info;
- const TensorInfo& scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType());
+ unsigned int numUnits = cellStateIn.GetShape()[1];
+ unsigned int scratchBufferFactor = 4;
+
+ if (descriptor.m_CifgEnabled)
+ {
+ // scratchBuffer = { batchSize, numUnits * 3 } with CIFG
+ scratchBufferFactor = 3;
+ }
+
+ const TensorInfo& scratchBuffer = TensorInfo({ batchSize, numUnits * scratchBufferFactor }, input.GetDataType());
+
lstm_params_info.set_cell_clip_params(descriptor.m_ClippingThresCell);
lstm_params_info.set_projection_clip_params(descriptor.m_ClippingThresProj);
@@ -707,7 +723,7 @@ NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input,
// Set input of LSTM to be first input ITensor.
// Set output of LSTM to be final output ITensor.
// LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo.
- if (maxTime == 1 && !descriptor.m_TimeMajor)
+ if (maxTime == 1 && descriptor.m_TimeMajor)
{
TensorShape inputShape = GetTensorShape(aclInputInfo.tensor_shape(), 1U);
TensorShape outputShape = GetTensorShape(aclOutputInfo.tensor_shape(), 1U);