From 1299496996bc332f02218f926640a9255ed60310 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Thu, 21 Apr 2022 11:57:09 +0100 Subject: IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon * Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported * outputStateOut TensorInfo is not optional. * cellStateOut TensorInfo is not optional. * TensorInfo Order matches other QLSTM/LSTM layers. * Added missing parameters to UnidirectionalSequenceLstmOperator for delegate. * Added quantized UnidirectionalSequenceLstm support to Neon !android-nn-driver:7457 Signed-off-by: Mike Kelly Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4 --- src/backends/backendsCommon/WorkloadData.cpp | 26 ++++---------------------- 1 file changed, 4 insertions(+), 22 deletions(-) (limited to 'src/backends/backendsCommon/WorkloadData.cpp') diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 70d7641f41..a2dcd63726 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -3860,38 +3860,20 @@ void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& wor { throw InvalidArgumentException(descriptorName + ": Invalid number of inputs."); } - if (workloadInfo.m_OutputTensorInfos.size() != 1) + if (workloadInfo.m_OutputTensorInfos.size() != 3) { throw InvalidArgumentException(descriptorName + ": Invalid number of outputs."); } std::vector supportedTypes = { - DataType::Float32 + DataType::Float32, + DataType::QAsymmS8 }; // check for supported type of one input and match them with all the other input and output ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName); - // type matches all other inputs - for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i) - { - ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0], - workloadInfo.m_InputTensorInfos[i], - descriptorName, - "input_0", - "input_" + std::to_string(i)); - } - // type matches all other outputs - for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i) - { - ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0], - workloadInfo.m_OutputTensorInfos[i], - "LstmQueueDescriptor", - "input_0", - "output_" + std::to_string(i)); - } - // Making sure clipping parameters have valid values. // == 0 means no clipping // > 0 means clipping @@ -3936,7 +3918,7 @@ void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& wor descriptorName + " input_2"); // outputTensor - ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 3, (timeStep * n_batch * n_output), + ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[2], 3, (timeStep * n_batch * n_output), descriptorName + " output_0"); // check that dimensions of inputs/outputs and QueueDescriptor data match with each other -- cgit v1.2.1