From 8ed39ae450a077c7e4d672b5f05ff1d68ee67aab Mon Sep 17 00:00:00 2001
From: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Date: Thu, 15 Jul 2021 16:16:25 +0100
Subject: MLCE-530 Add front end support for UnidirectionalSequenceLstm on ArmNN

Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: I57bcbdec3eb0155f41af0fe7d6abf9bac2ec86eb
---
 src/backends/backendsCommon/LayerSupportBase.cpp |  13 +
 src/backends/backendsCommon/LayerSupportBase.hpp |  11 +
 src/backends/backendsCommon/WorkloadData.cpp     | 276 ++++++++++++++++++++-
 src/backends/backendsCommon/WorkloadData.hpp     |  52 ++++
 src/backends/backendsCommon/WorkloadFactory.cpp  | 148 +++++++++++
 src/backends/backendsCommon/WorkloadFactory.hpp  |   4 +
 .../test/IsLayerSupportedTestImpl.hpp            |  53 ++++
 7 files changed, 555 insertions(+), 2 deletions(-)

diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 8a24e1161b..138d45367e 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -678,4 +678,17 @@ bool LayerSupportBase::IsTransposeSupported(const TensorInfo&, // input
     return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
 }
 
+bool LayerSupportBase::IsUnidirectionalSequenceLstmSupported(const TensorInfo&, // input
+                                                             const TensorInfo&, // outputStateIn
+                                                             const TensorInfo&, // cellStateIn
+                                                             const TensorInfo&, // output
+                                                             const Optional<TensorInfo>&, // hiddenStateOut
+                                                             const Optional<TensorInfo>&, // cellStateOut
+                                                             const LstmDescriptor&, // descriptor
+                                                             const LstmInputParamsInfo&, // paramsInfo
+                                                             Optional<std::string&> reasonIfUnsupported) const
+{
+    return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
 } // namespace armnn
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index 0277a782a1..533a2c6bdd 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -417,6 +417,17 @@ public:
                               const TransposeDescriptor& descriptor,
                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsUnidirectionalSequenceLstmSupported(
+        const TensorInfo& input,
+        const TensorInfo& outputStateIn,
+        const TensorInfo& cellStateIn,
+        const TensorInfo& output,
+        const Optional<TensorInfo>& hiddenStateOutput,
+        const Optional<TensorInfo>& cellStateOutput,
+        const LstmDescriptor& descriptor,
+        const LstmInputParamsInfo& paramsInfo,
+        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
 };
 
 } // namespace armnn
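LayerSupportBase deliberately answers "not supported" here, so each backend opts in by overriding the new method. A minimal sketch of such an override -- the MyLayerSupport class and its Float32-only policy are hypothetical, not part of this patch:

    // Hypothetical backend override -- illustrative only.
    bool MyLayerSupport::IsUnidirectionalSequenceLstmSupported(const TensorInfo& input,
                                                               const TensorInfo& outputStateIn,
                                                               const TensorInfo& cellStateIn,
                                                               const TensorInfo& output,
                                                               const Optional<TensorInfo>& hiddenStateOut,
                                                               const Optional<TensorInfo>& cellStateOut,
                                                               const LstmDescriptor& descriptor,
                                                               const LstmInputParamsInfo& paramsInfo,
                                                               Optional<std::string&> reasonIfUnsupported) const
    {
        // Accept Float32 only; report a reason otherwise.
        if (input.GetDataType() != DataType::Float32)
        {
            if (reasonIfUnsupported)
            {
                reasonIfUnsupported.value() = "Only Float32 is supported";
            }
            return false;
        }
        return true;
    }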
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 8c78136185..3fe0823b03 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1959,7 +1959,6 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
         throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
     }
 
-    // Inferring batch size, number of outputs and number of cells from the inputs.
     const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
 
     const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
@@ -1991,7 +1990,6 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
     ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[3], 2, (n_batch * n_output),
                                 descriptorName + " output_3");
 
-    // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
     if ( m_InputToInputWeights )
     {
@@ -3741,4 +3739,278 @@ void ReduceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
     ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
 }
 
+void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+    // Modified from LstmQueueDescriptor::Validate to support UnidirectionalSequenceLstm
+
+    const std::string descriptorName{"UnidirectionalSequenceLstmQueueDescriptor"};
+
+    // check dimensions of all inputs and outputs
+    if (workloadInfo.m_InputTensorInfos.size() != 3)
+    {
+        throw InvalidArgumentException(descriptorName + ": Invalid number of inputs.");
+    }
+    if (workloadInfo.m_OutputTensorInfos.size() != 1)
+    {
+        throw InvalidArgumentException(descriptorName + ": Invalid number of outputs.");
+    }
+
+    std::vector<DataType> supportedTypes =
+    {
+        DataType::Float16,
+        DataType::Float32,
+        DataType::QAsymmS8
+    };
+
+    // check that the data type of one input is supported, then match it against all other inputs and outputs
+    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, descriptorName);
+
+    // type matches all other inputs
+    for (uint32_t i = 1u; i < workloadInfo.m_InputTensorInfos.size(); ++i)
+    {
+        ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
+                                     workloadInfo.m_InputTensorInfos[i],
+                                     descriptorName,
+                                     "input_0",
+                                     "input_" + std::to_string(i));
+    }
+    // type matches all other outputs
+    for (uint32_t i = 0u; i < workloadInfo.m_OutputTensorInfos.size(); ++i)
+    {
+        ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
+                                     workloadInfo.m_OutputTensorInfos[i],
+                                     descriptorName,
+                                     "input_0",
+                                     "output_" + std::to_string(i));
+    }
+
+    // Make sure the clipping parameters have valid values.
+    // == 0 means no clipping
+    //  > 0 means clipping
+    if (m_Parameters.m_ClippingThresCell < 0.0f)
+    {
+        throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
+    }
+    if (m_Parameters.m_ClippingThresProj < 0.0f)
+    {
+        throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
+    }
+
+    // The input feature is always the last dimension of the 3D input tensor;
+    // batch and time swap when m_TimeMajor is set.
+    unsigned int batchIndx = 0;
+    unsigned int timeIndx  = 1;
+    const unsigned int inputIndx = 2;
+    if (m_Parameters.m_TimeMajor)
+    {
+        batchIndx = 1;
+        timeIndx  = 0;
+    }
+    const uint32_t timeStep = workloadInfo.m_InputTensorInfos[0].GetShape()[timeIndx];
+
+    // Inferring batch size, number of outputs and number of cells from the inputs.
+    const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[inputIndx];
+    const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[batchIndx];
+    ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
+    const uint32_t n_cell = m_InputToOutputWeights->GetShape()[0];
+    ValidatePointer(m_RecurrentToOutputWeights, "Null pointer check", "RecurrentToOutputWeights");
+    const uint32_t n_output = m_RecurrentToOutputWeights->GetShape()[1];
+
+    // input tensor
+    ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[0], 3, (timeStep * n_batch * n_input),
+                                descriptorName + " input_0");
+    // outputStateInTensor
+    ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[1], 2, (n_batch * n_output),
+                                descriptorName + " input_1");
+    // cellStateInTensor
+    ValidateTensorNumDimNumElem(workloadInfo.m_InputTensorInfos[2], 2, (n_batch * n_cell),
+                                descriptorName + " input_2");
+
+    // outputTensor
+    ValidateTensorNumDimNumElem(workloadInfo.m_OutputTensorInfos[0], 3, (timeStep * n_batch * n_output),
+                                descriptorName + " output_0");
+
+    // check that dimensions of inputs/outputs and QueueDescriptor data match with each other
+    if ( m_InputToInputWeights )
+    {
+        ValidateTensorNumDimNumElem(m_InputToInputWeights->GetTensorInfo(), 2,
+                                    (n_cell * n_input), "InputToInputWeights");
+    }
+
+    ValidatePointer(m_InputToForgetWeights, "Null pointer check", "InputToForgetWeights");
+    ValidateTensorNumDimNumElem(m_InputToForgetWeights->GetTensorInfo(), 2,
+                                (n_cell * n_input), "InputToForgetWeights");
+
+    ValidatePointer(m_InputToCellWeights, "Null pointer check", "InputToCellWeights");
+    ValidateTensorNumDimNumElem(m_InputToCellWeights->GetTensorInfo(), 2,
+                                (n_cell * n_input), "InputToCellWeights");
+
+    if ( m_RecurrentToInputWeights )
+    {
+        ValidateTensorNumDimNumElem(m_RecurrentToInputWeights->GetTensorInfo(), 2,
+                                    (n_cell * n_output), "RecurrentToInputWeights");
+    }
+
+    ValidatePointer(m_RecurrentToForgetWeights, "Null pointer check", "RecurrentToForgetWeights");
+    ValidateTensorNumDimNumElem(m_RecurrentToForgetWeights->GetTensorInfo(), 2,
+                                (n_cell * n_output), "RecurrentToForgetWeights");
+
+    ValidatePointer(m_RecurrentToCellWeights, "Null pointer check", "RecurrentToCellWeights");
+    ValidateTensorNumDimNumElem(m_RecurrentToCellWeights->GetTensorInfo(), 2,
+                                (n_cell * n_output), "RecurrentToCellWeights");
+
+    // Make sure the input-gate's parameters are either both present (regular
+    // LSTM) or not at all (CIFG-LSTM), and that m_CifgEnabled is set accordingly.
+    bool cifg_weights_all_or_none = ((m_InputToInputWeights && m_RecurrentToInputWeights &&
+                                      !m_Parameters.m_CifgEnabled) ||
+                                     (!m_InputToInputWeights && !m_RecurrentToInputWeights &&
+                                      m_Parameters.m_CifgEnabled));
+    if (!cifg_weights_all_or_none)
+    {
+        throw InvalidArgumentException(descriptorName + ": Input-Gate's parameters InputToInputWeights and "
+                                       "RecurrentToInputWeights must either both be present (regular LSTM) "
+                                       "or both not present (CIFG-LSTM). In addition, m_CifgEnabled must be set "
+                                       "accordingly.");
+    }
+
+    if ( m_CellToInputWeights )
+    {
+        ValidateTensorNumDimNumElem(m_CellToInputWeights->GetTensorInfo(), 1,
+                                    n_cell, "CellToInputWeights");
+    }
+    if ( m_CellToForgetWeights )
+    {
+        ValidateTensorNumDimNumElem(m_CellToForgetWeights->GetTensorInfo(), 1,
+                                    n_cell, "CellToForgetWeights");
+    }
+    if ( m_CellToOutputWeights )
+    {
+        ValidateTensorNumDimNumElem(m_CellToOutputWeights->GetTensorInfo(), 1,
+                                    n_cell, "CellToOutputWeights");
+    }
+
+    // Make sure the peephole weights are all present or all absent, and that m_PeepholeEnabled is set accordingly.
+    bool peephole_weights_all_or_none =
+            (((m_CellToInputWeights || m_Parameters.m_CifgEnabled) && m_CellToForgetWeights
+              && m_CellToOutputWeights && m_Parameters.m_PeepholeEnabled)
+             || ( !m_CellToInputWeights && !m_CellToForgetWeights
+                  && !m_CellToOutputWeights && !m_Parameters.m_PeepholeEnabled));
+    if (!peephole_weights_all_or_none)
+    {
+        throw InvalidArgumentException(descriptorName + ": Invalid combination of peephole parameters.");
+    }
+
+    // Make sure the input gate bias is present only when not a CIFG-LSTM.
+    if (m_Parameters.m_CifgEnabled)
+    {
+        if (m_InputGateBias)
+        {
+            throw InvalidArgumentException(descriptorName + ": InputGateBias is present and CIFG-LSTM is enabled.");
+        }
+    }
+    else
+    {
+        if (!m_InputGateBias)
+        {
+            throw InvalidArgumentException(descriptorName + ": If CIFG-LSTM is disabled InputGateBias "
+                                           "must be present.");
+        }
+        ValidateTensorNumDimNumElem(m_InputGateBias->GetTensorInfo(), 1,
+                                    n_cell, "InputGateBias");
+    }
+
+    ValidatePointer(m_ForgetGateBias, "Null pointer check", "ForgetGateBias");
+    ValidateTensorNumDimNumElem(m_ForgetGateBias->GetTensorInfo(), 1, n_cell, "ForgetGateBias");
+
+    ValidatePointer(m_CellBias, "Null pointer check", "CellBias");
+    ValidateTensorNumDimNumElem(m_CellBias->GetTensorInfo(), 1, n_cell, "CellBias");
+
+    ValidatePointer(m_OutputGateBias, "Null pointer check", "OutputGateBias");
+    ValidateTensorNumDimNumElem(m_OutputGateBias->GetTensorInfo(), 1, n_cell, "OutputGateBias");
+
+    if (m_ProjectionWeights)
+    {
+        ValidateTensorNumDimNumElem(m_ProjectionWeights->GetTensorInfo(), 2,
+                                    (n_cell * n_output), "ProjectionWeights");
+    }
+    if (m_ProjectionBias)
+    {
+        ValidateTensorNumDimNumElem(m_ProjectionBias->GetTensorInfo(), 1, n_output, "ProjectionBias");
+    }
+
+    // Make sure the projection tensors are consistent:
+    // 1) If projection weight is not present, then projection bias should not be
+    //    present.
+    // 2) If projection weight is present, then projection bias is optional.
+    bool projection_tensors_consistent = ((!m_ProjectionWeights && !m_ProjectionBias &&
+                                           !m_Parameters.m_ProjectionEnabled)
+                                          || (m_ProjectionWeights && !m_ProjectionBias &&
+                                              m_Parameters.m_ProjectionEnabled)
+                                          || (m_ProjectionWeights && m_ProjectionBias &&
+                                              m_Parameters.m_ProjectionEnabled));
+    if (!projection_tensors_consistent)
+    {
+        throw InvalidArgumentException(descriptorName + ": Projection tensors are inconsistent.");
+    }
+
+    // The four layer normalization weights either all have values or none of them have values. Additionally, if
+    // CIFG is used, the input layer normalization weights tensor is omitted and the other layer normalization
+    // weights either all have values or none of them have values. Layer normalization is used when the values of
+    // all the layer normalization weights are present.
+    if (m_InputLayerNormWeights)
+    {
+        ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(), 1, n_cell, "InputLayerNormWeights");
+    }
+    if (m_ForgetLayerNormWeights)
+    {
+        ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
+    }
+    if (m_CellLayerNormWeights)
+    {
+        ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
+    }
+    if (m_OutputLayerNormWeights)
+    {
+        ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
+    }
+
+    if (m_Parameters.m_LayerNormEnabled)
+    {
+        if (!m_Parameters.m_CifgEnabled)
+        {
+            if (!m_InputLayerNormWeights)
+            {
+                throw InvalidArgumentException(descriptorName + ": Layer normalisation is enabled and CIFG-LSTM is "
+                                               "disabled but InputLayerNormWeights are not present");
+            }
+            ValidateTensorNumDimNumElem(m_InputLayerNormWeights->GetTensorInfo(),
+                                        1, n_cell, "InputLayerNormWeights");
+        }
+        else if (m_InputLayerNormWeights)
+        {
+            throw InvalidArgumentException(descriptorName + ": InputLayerNormWeights are present while CIFG is "
+                                           "enabled");
+        }
+
+        ValidatePointer(m_ForgetLayerNormWeights, "Null pointer check layer normalisation enabled",
+                        "ForgetLayerNormWeights");
+        ValidateTensorNumDimNumElem(m_ForgetLayerNormWeights->GetTensorInfo(), 1, n_cell, "ForgetLayerNormWeights");
+
+        ValidatePointer(m_OutputLayerNormWeights, "Null pointer check layer normalisation enabled",
+                        "OutputLayerNormWeights");
+        ValidateTensorNumDimNumElem(m_OutputLayerNormWeights->GetTensorInfo(), 1, n_cell, "OutputLayerNormWeights");
+
+        ValidatePointer(m_CellLayerNormWeights, "Null pointer check layer normalisation enabled",
+                        "CellLayerNormWeights");
+        ValidateTensorNumDimNumElem(m_CellLayerNormWeights->GetTensorInfo(), 1, n_cell, "CellLayerNormWeights");
+    }
+    else if (m_InputLayerNormWeights || m_ForgetLayerNormWeights || m_OutputLayerNormWeights ||
+             m_CellLayerNormWeights)
+    {
+        throw InvalidArgumentException(descriptorName + ": Layer normalisation is disabled but one or more layer "
+                                       "normalisation weights are present.");
+    }
+}
+
+
 } // namespace armnn
\ No newline at end of file
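The shape contract that UnidirectionalSequenceLstmQueueDescriptor::Validate enforces can be summarised as follows (a sketch of the batch-major case; with m_TimeMajor set, the batch and time dimensions of input_0 and output_0 swap):

    // input_0  (input):         [n_batch, timeStep, n_input]
    // input_1  (outputStateIn): [n_batch, n_output]
    // input_2  (cellStateIn):   [n_batch, n_cell]
    // output_0 (output):        [n_batch, timeStep, n_output]
    //
    // n_cell and n_output are inferred from the mandatory weights:
    //   n_cell   = m_InputToOutputWeights->GetShape()[0];     // weights are [n_cell, n_input]
    //   n_output = m_RecurrentToOutputWeights->GetShape()[1]; // weights are [n_cell, n_output]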
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 36653bdc0d..78da00be5d 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -695,4 +695,56 @@ struct ShapeQueueDescriptor : QueueDescriptor
     void Validate(const WorkloadInfo& workloadInfo) const;
 };
 
+struct UnidirectionalSequenceLstmQueueDescriptor : QueueDescriptorWithParameters<LstmDescriptor>
+{
+    UnidirectionalSequenceLstmQueueDescriptor()
+        : m_InputToInputWeights(nullptr)
+        , m_InputToForgetWeights(nullptr)
+        , m_InputToCellWeights(nullptr)
+        , m_InputToOutputWeights(nullptr)
+        , m_RecurrentToInputWeights(nullptr)
+        , m_RecurrentToForgetWeights(nullptr)
+        , m_RecurrentToCellWeights(nullptr)
+        , m_RecurrentToOutputWeights(nullptr)
+        , m_CellToInputWeights(nullptr)
+        , m_CellToForgetWeights(nullptr)
+        , m_CellToOutputWeights(nullptr)
+        , m_InputGateBias(nullptr)
+        , m_ForgetGateBias(nullptr)
+        , m_CellBias(nullptr)
+        , m_OutputGateBias(nullptr)
+        , m_ProjectionWeights(nullptr)
+        , m_ProjectionBias(nullptr)
+        , m_InputLayerNormWeights(nullptr)
+        , m_ForgetLayerNormWeights(nullptr)
+        , m_CellLayerNormWeights(nullptr)
+        , m_OutputLayerNormWeights(nullptr)
+    {
+    }
+
+    const ConstTensorHandle* m_InputToInputWeights;
+    const ConstTensorHandle* m_InputToForgetWeights;
+    const ConstTensorHandle* m_InputToCellWeights;
+    const ConstTensorHandle* m_InputToOutputWeights;
+    const ConstTensorHandle* m_RecurrentToInputWeights;
+    const ConstTensorHandle* m_RecurrentToForgetWeights;
+    const ConstTensorHandle* m_RecurrentToCellWeights;
+    const ConstTensorHandle* m_RecurrentToOutputWeights;
+    const ConstTensorHandle* m_CellToInputWeights;
+    const ConstTensorHandle* m_CellToForgetWeights;
+    const ConstTensorHandle* m_CellToOutputWeights;
+    const ConstTensorHandle* m_InputGateBias;
+    const ConstTensorHandle* m_ForgetGateBias;
+    const ConstTensorHandle* m_CellBias;
+    const ConstTensorHandle* m_OutputGateBias;
+    const ConstTensorHandle* m_ProjectionWeights;
+    const ConstTensorHandle* m_ProjectionBias;
+    const ConstTensorHandle* m_InputLayerNormWeights;
+    const ConstTensorHandle* m_ForgetLayerNormWeights;
+    const ConstTensorHandle* m_CellLayerNormWeights;
+    const ConstTensorHandle* m_OutputLayerNormWeights;
+
+    void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
 } // namespace armnn
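To make the new descriptor concrete, here is a sketch of how a caller might populate and validate it; the handle variables are hypothetical ConstTensorHandle pointers owned elsewhere, and only the mandatory CIFG-LSTM fields are shown:

    UnidirectionalSequenceLstmQueueDescriptor data;
    data.m_Parameters.m_TimeMajor   = false;
    data.m_Parameters.m_CifgEnabled = true;  // input-gate members stay nullptr
    data.m_InputToForgetWeights     = inputToForgetWeightsHandle;
    data.m_InputToCellWeights       = inputToCellWeightsHandle;
    data.m_InputToOutputWeights     = inputToOutputWeightsHandle;
    data.m_RecurrentToForgetWeights = recurrentToForgetWeightsHandle;
    data.m_RecurrentToCellWeights   = recurrentToCellWeightsHandle;
    data.m_RecurrentToOutputWeights = recurrentToOutputWeightsHandle;
    data.m_ForgetGateBias           = forgetGateBiasHandle;
    data.m_CellBias                 = cellBiasHandle;
    data.m_OutputGateBias           = outputGateBiasHandle;

    WorkloadInfo info;
    info.m_InputTensorInfos  = { inputInfo, outputStateInInfo, cellStateInInfo };
    info.m_OutputTensorInfos = { outputInfo };
    data.Validate(info);  // throws InvalidArgumentException on any inconsistency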
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index dc70e6a9c2..1c18551679 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -1277,6 +1277,147 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
                                                   reason);
             break;
         }
+        case LayerType::UnidirectionalSequenceLstm:
+        {
+            auto cLayer = PolymorphicDowncast<const UnidirectionalSequenceLstmLayer*>(&layer);
+            const UnidirectionalSequenceLstmDescriptor& descriptor = cLayer->GetParameters();
+
+            // All inputs.
+            const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
+                                                       dataType);
+            const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
+                                                               dataType);
+            const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
+                                                             dataType);
+            // Outputs
+            const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
+
+            // Basic parameters
+            const TensorInfo& inputToForgetWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
+            const TensorInfo& inputToCellWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
+            const TensorInfo& inputToOutputWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
+            const TensorInfo& recurrentToForgetWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(),
+                                       dataType);
+            const TensorInfo& recurrentToCellWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
+            const TensorInfo& recurrentToOutputWeights
+                    = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(),
+                                       dataType);
+            const TensorInfo& forgetGateBias
+                    = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
+            const TensorInfo& cellBias
+                    = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
+            const TensorInfo& outputGateBias
+                    = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
+
+            LstmInputParamsInfo paramsInfo;
+
+            paramsInfo.m_InputToForgetWeights     = &inputToForgetWeights;
+            paramsInfo.m_InputToCellWeights       = &inputToCellWeights;
+            paramsInfo.m_InputToOutputWeights     = &inputToOutputWeights;
+            paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
+            paramsInfo.m_RecurrentToCellWeights   = &recurrentToCellWeights;
+            paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
+            paramsInfo.m_ForgetGateBias           = &forgetGateBias;
+            paramsInfo.m_CellBias                 = &cellBias;
+            paramsInfo.m_OutputGateBias           = &outputGateBias;
+
+            // Optional parameters
+            TensorInfo optInputToInputWeights;
+            TensorInfo optRecurrentToInputWeights;
+            TensorInfo optCellToInputWeights;
+            TensorInfo optInputGateBias;
+            TensorInfo optProjectionWeights;
+            TensorInfo optProjectionBias;
+            TensorInfo optCellToForgetWeights;
+            TensorInfo optCellToOutputWeights;
+            TensorInfo optInputLayerNormWeights;
+            TensorInfo optForgetLayerNormWeights;
+            TensorInfo optCellLayerNormWeights;
+            TensorInfo optOutputLayerNormWeights;
+
+            if(!descriptor.m_CifgEnabled)
+            {
+                optInputToInputWeights =
+                    OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
+
+                optRecurrentToInputWeights =
+                    OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
+                optInputGateBias =
+                    OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
+                paramsInfo.m_InputGateBias = &optInputGateBias;
+            }
+
+            if(descriptor.m_ProjectionEnabled)
+            {
+                optProjectionWeights =
+                    OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_ProjectionWeights = &optProjectionWeights;
+                if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
+                {
+                    optProjectionBias =
+                        OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
+                    paramsInfo.m_ProjectionBias = &optProjectionBias;
+                }
+            }
+
+            if(descriptor.m_PeepholeEnabled)
+            {
+                if(!descriptor.m_CifgEnabled)
+                {
+                    optCellToInputWeights =
+                        OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
+                                         dataType);
+                    paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
+                }
+                optCellToForgetWeights =
+                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
+                optCellToOutputWeights =
+                    OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
+            }
+
+            if(descriptor.m_LayerNormEnabled)
+            {
+                if (!descriptor.m_CifgEnabled)
+                {
+                    optInputLayerNormWeights = OverrideDataType(
+                            cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
+                    paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
+                }
+
+                optForgetLayerNormWeights = OverrideDataType(
+                        cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
+
+                optCellLayerNormWeights = OverrideDataType(
+                        cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
+
+                optOutputLayerNormWeights = OverrideDataType(
+                        cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
+                paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
+            }
+
+            Optional<TensorInfo> hiddenStateOut;
+            Optional<TensorInfo> cellStateOut;
+
+            result = layerSupportObject.IsUnidirectionalSequenceLstmSupported(input,
+                                                                              outputStateIn,
+                                                                              cellStateIn,
+                                                                              output,
+                                                                              hiddenStateOut,
+                                                                              cellStateOut,
+                                                                              descriptor,
+                                                                              paramsInfo,
+                                                                              reason);
+            break;
+        }
         default:
         {
             ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
@@ -1759,4 +1900,11 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
     return std::unique_ptr<IWorkload>();
 }
 
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateUnidirectionalSequenceLstm(
+    const UnidirectionalSequenceLstmQueueDescriptor& /*descriptor*/,
+    const WorkloadInfo& /*info*/) const
+{
+    return std::unique_ptr<IWorkload>();
+}
+
 } // namepsace armnn
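The base-class CreateUnidirectionalSequenceLstm above returns an empty pointer, so factories that never see the layer keep compiling and simply produce no workload. A backend that implements the layer would override it; a hypothetical sketch (MyWorkloadFactory and MyUnidirectionalSequenceLstmWorkload are illustrative names, not part of this patch):

    std::unique_ptr<IWorkload> MyWorkloadFactory::CreateUnidirectionalSequenceLstm(
        const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
        const WorkloadInfo& info) const
    {
        // Descriptor and info have already been validated by the runtime.
        return std::make_unique<MyUnidirectionalSequenceLstmWorkload>(descriptor, info);
    }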
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index 1987b9b664..efb8d99fa0 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -289,6 +289,10 @@ public:
         const TransposeConvolution2dQueueDescriptor& descriptor,
         const WorkloadInfo& info) const;
 
+    virtual std::unique_ptr<IWorkload> CreateUnidirectionalSequenceLstm(
+        const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
+        const WorkloadInfo& info) const;
+
 private:
     static bool IsLayerConfigurationSupported(const BackendId& backendId,
                                               const IConnectableLayer& connectableLayer,
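In the test changes below, note that IsLayerConfigurationSupported dereferences the m_BasicParameters handles unconditionally, so the dummy layer must populate every mandatory tensor. A hypothetical helper that would condense the repeated initialisation (same effect as each make_unique call in the patch):

    std::unique_ptr<armnn::ScopedTensorHandle> MakeDummyWeightHandle()
    {
        // 1x1x1x1 Float32 placeholder, as used throughout the dummy layer.
        return std::make_unique<armnn::ScopedTensorHandle>(
                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
    }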
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index ddd6eacb6d..21b33d297b 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -342,6 +342,56 @@ struct DummyLayer
 {
 };
 
+template<typename UnidirectionalSequenceLstmLayerType>
+struct DummyUnidirectionalSequenceLstmLayer
+{
+    DummyUnidirectionalSequenceLstmLayer()
+    {
+        typename UnidirectionalSequenceLstmLayerType::DescriptorType desc;
+        desc.m_CifgEnabled = false;
+
+        m_Layer = dummyGraph.AddLayer<UnidirectionalSequenceLstmLayerType>(desc, "");
+        m_Layer->m_BasicParameters.m_InputToForgetWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_InputToCellWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_InputToOutputWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_RecurrentToForgetWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_RecurrentToCellWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_RecurrentToOutputWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_ForgetGateBias = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_CellBias = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_BasicParameters.m_OutputGateBias = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+
+        m_Layer->m_CifgParameters.m_InputToInputWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_CifgParameters.m_RecurrentToInputWeights = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+        m_Layer->m_CifgParameters.m_InputGateBias = std::make_unique<armnn::ScopedTensorHandle>(
+                armnn::TensorInfo(armnn::TensorShape({1,1,1,1}), armnn::DataType::Float32));
+    }
+
+    ~DummyUnidirectionalSequenceLstmLayer()
+    {
+        dummyGraph.EraseLayer(m_Layer);
+    }
+
+    armnn::UnidirectionalSequenceLstmLayer* m_Layer;
+};
+
+template<>
+struct DummyLayer<armnn::UnidirectionalSequenceLstmLayer>
+    : public DummyUnidirectionalSequenceLstmLayer<armnn::UnidirectionalSequenceLstmLayer>
+{
+};
+
 template<>
 struct DummyLayer
 {
@@ -651,6 +701,7 @@ DECLARE_LAYER_POLICY_2_PARAM(Pooling2d)
 DECLARE_LAYER_POLICY_2_PARAM(PreCompiled)
 
 DECLARE_LAYER_POLICY_1_PARAM(Prelu)
+
 DECLARE_LAYER_POLICY_2_PARAM(QLstm)
 
 DECLARE_LAYER_POLICY_1_PARAM(QuantizedLstm)
@@ -691,6 +742,8 @@ DECLARE_LAYER_POLICY_2_PARAM(Transpose)
 
 DECLARE_LAYER_POLICY_2_PARAM(TransposeConvolution2d)
 
+DECLARE_LAYER_POLICY_2_PARAM(UnidirectionalSequenceLstm)
+
 DECLARE_LAYER_POLICY_MAP_PARAM(Unmap, void)
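For orientation, a minimal sketch of exercising the new front end at graph level, mirroring the dummy layer above (assumes the mandatory m_BasicParameters handles are then filled in exactly as the test does):

    armnn::Graph graph;
    armnn::UnidirectionalSequenceLstmDescriptor desc;
    desc.m_CifgEnabled = true;   // omit all input-gate parameters
    desc.m_TimeMajor   = false;  // input_0 laid out as [batch, time, feature]
    auto* lstm = graph.AddLayer<armnn::UnidirectionalSequenceLstmLayer>(desc, "useqlstm");
    // lstm->m_BasicParameters.m_* must be populated with ScopedTensorHandle
    // objects before IsLayerSupported / workload creation is attempted.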