diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-04-21 11:57:09 +0100 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2022-05-05 08:29:20 +0000 |
commit | 1299496996bc332f02218f926640a9255ed60310 (patch) | |
tree | 2d242e142bd8fe7387140bcf8cdf39cd13ffc9eb /src/backends/reference | |
parent | 8272a7bda2974c39b6c06e3eb3a000f2bdb749f7 (diff) | |
download | armnn-1299496996bc332f02218f926640a9255ed60310.tar.gz |
IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon
* Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported
* outputStateOut TensorInfo is not optional.
* cellStateOut TensorInfo is not optional.
* TensorInfo Order matches other QLSTM/LSTM layers.
* Added missing parameters to UnidirectionalSequenceLstmOperator for
the delegate.
* Added quantized UnidirectionalSequenceLstm support to Neon
!android-nn-driver:7457
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 112 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 4 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp | 8 |
3 files changed, 41 insertions, 83 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 66661cb521..919c6db6ff 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -465,57 +465,15 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type, "hiddenStateOutputVal, cellStateOutputVal, output}"); } auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor)); - - bool isHiddenStateOutputOptional = (infos[4] == TensorInfo()); - bool isCellStateOutput = (infos[5] == TensorInfo()); - if (isHiddenStateOutputOptional && isCellStateOutput) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - EmptyOptional(), - EmptyOptional(), - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else if (isHiddenStateOutputOptional) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - EmptyOptional(), - infos[5], - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else if (isCellStateOutput) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - EmptyOptional(), - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - infos[5], - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } + return IsUnidirectionalSequenceLstmSupported(infos[0], + infos[1], + infos[2], + infos[3], + infos[4], + infos[5], + desc, + lstmParamsInfo.value(), + reasonIfUnsupported); } case LayerType::Pooling3d: return IsPooling3dSupported(infos[0], @@ -2841,9 +2799,9 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, const TensorInfo& output, - 
const Optional<TensorInfo>& hiddenStateOutput, - const Optional<TensorInfo>& cellStateOutput, const UnidirectionalSequenceLstmDescriptor& descriptor, const LstmInputParamsInfo& paramsInfo, Optional<std::string&> reasonIfUnsupported) const @@ -2852,17 +2810,14 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( IgnoreUnused(paramsInfo); IgnoreUnused(outputStateIn); IgnoreUnused(cellStateIn); + IgnoreUnused(outputStateOut); + IgnoreUnused(cellStateOut); bool supported = true; - if (hiddenStateOutput.has_value() || cellStateOutput.has_value()) + std::array<DataType, 2> supportedTypes = { - reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output " - "and cell state output are not supported at the moment."; - } - - std::array<DataType, 1> supportedTypes = - { - DataType::Float32 + DataType::Float32, + DataType::QAsymmS8 }; std::array<DataType, 2> supportedWeightTypes = @@ -2871,16 +2826,19 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( DataType::QAsymmS8 }; + std::array<DataType, 3> supportedBiasTypes = + { + DataType::Float32, + DataType::QAsymmS8, + DataType::Signed32 + }; + // check inputs and outputs supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, "Reference UnidirectionalSequenceLstm: input is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched"); + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: output is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and output 
types are mismatched"); // check layer parameters supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes), reasonIfUnsupported, @@ -2905,14 +2863,13 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( reasonIfUnsupported, "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights " "is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types " - "are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and OutputGateBias types " - "are mismatched"); + + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellBias is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type."); if (!descriptor.m_CifgEnabled) { supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes), @@ -2923,9 +2880,8 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( reasonIfUnsupported, "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights " "is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported, - 
"Reference UnidirectionalSequenceLstm: input and InputGateBias types " - "are mismatched"); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type."); if (descriptor.m_PeepholeEnabled) { supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes), diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 98770ad64a..aa8bd8dda4 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -367,9 +367,9 @@ public: const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, const TensorInfo& output, - const Optional<TensorInfo>& hiddenStateOutput, - const Optional<TensorInfo>& cellStateOutput, const UnidirectionalSequenceLstmDescriptor& descriptor, const LstmInputParamsInfo& paramsInfo, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp index d447a46b23..c4345d4978 100644 --- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -59,7 +59,9 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> TensorInfo inputInfo = GetTensorInfo(inputs[0]); const TensorInfo& outputStateInfo = GetTensorInfo(inputs[1]); const TensorInfo& cellStateInfo = GetTensorInfo(inputs[2]); - TensorInfo outputInfo = GetTensorInfo(outputs[0]); + TensorInfo outputStateOutInfo = GetTensorInfo(outputs[0]); + TensorInfo cellStateOutInfo = GetTensorInfo(outputs[1]); + TensorInfo outputInfo = GetTensorInfo(outputs[2]); 
TensorShape& inputShape = inputInfo.GetShape(); TensorShape& outputShape= outputInfo.GetShape(); auto inputTensor = reinterpret_cast<float*>(inputs[0]->Map()); @@ -140,7 +142,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> auto currentInputData = reinterpret_cast<float*>(inputs[0]->Map()); std::unique_ptr<Decoder<float>> inputData = MakeDecoder<float>(lstmInputInfo, currentInputData); - auto currentOutputData = reinterpret_cast<float*>(outputs[0]->Map()); + auto currentOutputData = reinterpret_cast<float*>(outputs[2]->Map()); std::unique_ptr<Encoder<float>> output = MakeEncoder<float>(lstmOutputInfo, currentOutputData); std::unique_ptr<Decoder<float>> outputDecoder = MakeDecoder<float>(lstmOutputInfo, currentOutputData); @@ -296,7 +298,7 @@ void RefUnidirectionalSequenceLstmWorkload::Execute(std::vector<ITensorHandle*> { // Permute Output back to batch major const PermutationVector& mappings = {1U, 0U, 2U}; - auto outputData = reinterpret_cast<float*>(outputs[0]->Map()); + auto outputData = reinterpret_cast<float*>(outputs[2]->Map()); std::vector<float> outputValue(outputData, outputData + outputInfo.GetNumElements()); outputShape = armnnUtils::Permuted(outputInfo.GetShape(), mappings); outputInfo.SetShape(outputShape); |