diff options
Diffstat (limited to 'src/backends/neon')
3 files changed, 24 insertions, 5 deletions
diff --git a/src/backends/neon/test/NeonLayerTests.cpp b/src/backends/neon/test/NeonLayerTests.cpp index ae8352d09e..588c90be6d 100644 --- a/src/backends/neon/test/NeonLayerTests.cpp +++ b/src/backends/neon/test/NeonLayerTests.cpp @@ -1092,6 +1092,8 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32TimeMajorSin UnidirectionalSequenceLstmLayerFloat32TimeMajorSingleBatchTest) ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32BatchMajorSingleBatch, UnidirectionalSequenceLstmLayerFloat32BatchMajorSingleBatchTest) +ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32TimeMajorSingleTime, + UnidirectionalSequenceLstmLayerFloat32TimeMajorSingleTimeTest) ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32, UnidirectionalSequenceLstmLayerFloat32Test) ARMNN_AUTO_TEST_CASE_WITH_THF(UnidirectionalSequenceLstmLayerFloat32TimeMajor, diff --git a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp index e48425e3ee..bbdcd1f855 100644 --- a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp +++ b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmFloatWorkload.cpp @@ -603,7 +603,8 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, statusSplit = arm_compute::NESplit::validate(&aclPermuteOutInfo, splitterOutputsTensorInfosPtr, aclAxisSplit); - } else + } + else { statusSplit = arm_compute::NESplit::validate(&aclInputInfo, splitterOutputsTensorInfosPtr, aclAxisSplit); } @@ -740,7 +741,7 @@ NeonUnidirectionalSequenceLstmFloatWorkloadValidate(const TensorInfo& input, // Set input of LSTM to be first input ITensor. // Set output of LSTM to be final output ITensor. // LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo. - if (maxTime == 1 && !descriptor.m_TimeMajor) + if (maxTime == 1 && descriptor.m_TimeMajor) { TensorShape inputShape = GetTensorShape(aclInputInfo.tensor_shape(), 1U); TensorShape outputShape = GetTensorShape(aclOutputInfo.tensor_shape(), 1U); diff --git a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmWorkload.cpp b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmWorkload.cpp index 8a1747edd1..984a5dc549 100644 --- a/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmWorkload.cpp +++ b/src/backends/neon/workloads/NeonUnidirectionalSequenceLstmWorkload.cpp @@ -500,6 +500,12 @@ NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input, TensorShape inputLayerShape = input.GetShape(); TensorShape outputLayerShape = output.GetShape(); + if (inputLayerShape.GetNumDimensions() != 3) + { + return arm_compute::Status(arm_compute::ErrorCode::RUNTIME_ERROR, + "Unidirectional Sequence LSTM layer validate status failed."); + } + unsigned int maxTime = descriptor.m_TimeMajor ? inputLayerShape[0] : inputLayerShape[1]; unsigned int batchSize = descriptor.m_TimeMajor ? inputLayerShape[1] : inputLayerShape[0]; unsigned int inputSize = inputLayerShape[2]; @@ -525,7 +531,7 @@ NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input, // // Permute validate // - TensorInfo permuteOutInfo = TensorInfo(input); + TensorInfo permuteOutInfo = armnnUtils::Permuted(input, { 1U, 0U, 2U }); arm_compute::TensorInfo aclPermuteOutInfo = armcomputetensorutils::BuildArmComputeTensorInfo(permuteOutInfo); if (!descriptor.m_TimeMajor) { @@ -590,7 +596,17 @@ NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input, arm_compute::LSTMParams<arm_compute::ITensorInfo> lstm_params_info; - const TensorInfo& scratchBuffer = TensorInfo(cellStateIn.GetShape(), input.GetDataType()); + unsigned int numUnits = cellStateIn.GetShape()[1]; + unsigned int scratchBufferFactor = 4; + + if (descriptor.m_CifgEnabled) + { + // scratchBuffer = { batchSize, numUnits * 3 } with CIFG + scratchBufferFactor = 3; + } + + const TensorInfo& scratchBuffer = TensorInfo({ batchSize, numUnits * scratchBufferFactor }, input.GetDataType()); + lstm_params_info.set_cell_clip_params(descriptor.m_ClippingThresCell); lstm_params_info.set_projection_clip_params(descriptor.m_ClippingThresProj); @@ -707,7 +723,7 @@ NeonUnidirectionalSequenceLstmWorkloadValidate(const TensorInfo& input, // Set input of LSTM to be first input ITensor. // Set output of LSTM to be final output ITensor. // LSTM input/output cannot be > 2 dimensions so need to resize its TensorInfo. - if (maxTime == 1 && !descriptor.m_TimeMajor) + if (maxTime == 1 && descriptor.m_TimeMajor) { TensorShape inputShape = GetTensorShape(aclInputInfo.tensor_shape(), 1U); TensorShape outputShape = GetTensorShape(aclOutputInfo.tensor_shape(), 1U); |