From 1299496996bc332f02218f926640a9255ed60310 Mon Sep 17 00:00:00 2001 From: Mike Kelly Date: Thu, 21 Apr 2022 11:57:09 +0100 Subject: IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon * Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported * outputStateOut TensorInfo is not optional. * cellStateOut TensorInfo is not optional. * TensorInfo Order matches other QLSTM/LSTM layers. * Added missing parameters to UnidirectionalSequenceLstmOperator for delegate. * Added quantized UnidirectionalSequenceLstm support to Neon !android-nn-driver:7457 Signed-off-by: Mike Kelly Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4 --- src/backends/reference/RefLayerSupport.cpp | 112 +++++++-------------- src/backends/reference/RefLayerSupport.hpp | 4 +- .../RefUnidirectionalSequenceLstmWorkload.cpp | 8 +- 3 files changed, 41 insertions(+), 83 deletions(-) (limited to 'src/backends/reference') diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 66661cb521..919c6db6ff 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -465,57 +465,15 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type, "hiddenStateOutputVal, cellStateOutputVal, output}"); } auto desc = *(PolymorphicDowncast(&descriptor)); - - bool isHiddenStateOutputOptional = (infos[4] == TensorInfo()); - bool isCellStateOutput = (infos[5] == TensorInfo()); - if (isHiddenStateOutputOptional && isCellStateOutput) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - EmptyOptional(), - EmptyOptional(), - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else if (isHiddenStateOutputOptional) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - EmptyOptional(), - infos[5], - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else if (isCellStateOutput) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - EmptyOptional(), - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - infos[5], - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } + return IsUnidirectionalSequenceLstmSupported(infos[0], + infos[1], + infos[2], + infos[3], + infos[4], + infos[5], + desc, + lstmParamsInfo.value(), + reasonIfUnsupported); } case LayerType::Pooling3d: return IsPooling3dSupported(infos[0], @@ -2841,9 +2799,9 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, const TensorInfo& output, - const Optional& hiddenStateOutput, - const Optional& cellStateOutput, const UnidirectionalSequenceLstmDescriptor& descriptor, const LstmInputParamsInfo& paramsInfo, Optional reasonIfUnsupported) const @@ -2852,17 +2810,14 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( IgnoreUnused(paramsInfo); IgnoreUnused(outputStateIn); IgnoreUnused(cellStateIn); + IgnoreUnused(outputStateOut); + IgnoreUnused(cellStateOut); bool supported = true; - if (hiddenStateOutput.has_value() || cellStateOutput.has_value()) + std::array supportedTypes = { - reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output " - "and cell state output are not supported at the moment."; - } - - std::array supportedTypes = - { - DataType::Float32 + DataType::Float32, + DataType::QAsymmS8 }; std::array supportedWeightTypes = @@ -2871,16 +2826,19 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( DataType::QAsymmS8 }; + std::array supportedBiasTypes = + { + DataType::Float32, + DataType::QAsymmS8, + DataType::Signed32 + }; + // check inputs and outputs supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, "Reference UnidirectionalSequenceLstm: input is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched"); + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: output is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and output types are mismatched"); // check layer parameters supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes), reasonIfUnsupported, @@ -2905,14 +2863,13 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( reasonIfUnsupported, "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights " "is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types " - "are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and OutputGateBias types " - "are mismatched"); + + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellBias is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type."); if (!descriptor.m_CifgEnabled) { supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes), @@ -2923,9 +2880,8 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( reasonIfUnsupported, "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights " "is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and InputGateBias types " - "are mismatched"); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type."); if (descriptor.m_PeepholeEnabled) { supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes), diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 98770ad64a..aa8bd8dda4 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -367,9 +367,9 @@ public: const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, const TensorInfo& output, - const Optional& hiddenStateOutput, - const Optional& cellStateOutput, const UnidirectionalSequenceLstmDescriptor& descriptor, const LstmInputParamsInfo& paramsInfo, Optional reasonIfUnsupported = EmptyOptional()) const override; diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp index d447a46b23..c4345d4978 100644 --- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -59,7 +59,9 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector TensorInfo inputInfo = GetTensorInfo(inputs[0]); const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]); const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]); - TensorInfo outputInfo = GetTensorInfo(outputs[0]); + TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]); + TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]); + TensorInfo outputInfo = GetTensorInfo(outputs[2]); TensorShape& inputShape = inputInfo.GetShape(); TensorShape& outputShape= outputInfo.GetShape(); auto inputTensor = reinterpret_cast(inputs[0]->Map()); @@ -140,7 +142,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector auto currentInputData = reinterpret_cast(inputs[0]->Map()); std::unique_ptr> inputData = MakeDecoder(lstmInputInfo, currentInputData); - auto currentOutputData = reinterpret_cast(outputs[0]->Map()); + auto currentOutputData = reinterpret_cast(outputs[2]->Map()); std::unique_ptr> output = MakeEncoder(lstmOutputInfo, currentOutputData); std::unique_ptr> outputDecoder = MakeDecoder(lstmOutputInfo, currentOutputData); @@ -296,7 +298,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector { // Permute Output back to batch major const PermutationVector& mappings = {1U, 0U, 2U}; - auto outputData = reinterpret_cast(outputs[0]->Map()); + auto outputData = reinterpret_cast(outputs[2]->Map()); std::vector outputValue(outputData, outputData + outputInfo.GetNumElements()); outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); outputInfo.SetShape(outputShape); -- cgit v1.2.1