From e5339e7013cf24e5a34509fb0a60377e5f8a244e Mon Sep 17 00:00:00 2001 From: Narumol Prangnawarat Date: Wed, 28 Jul 2021 17:33:28 +0100 Subject: MLCE-530 Add support for UnidirectionalSequenceLstm to RefWorkload * Add implementation of IsUnidirectionalSequenceLstmSupported to RefLayerSupport * Add RefUnidirectionalSequenceLstmWorkload * Refactor Lstm to be able to use for Lstm and SequenceLstm * Unit tests Signed-off-by: Narumol Prangnawarat Change-Id: Ibc066d213213a11b955dfefbe518de643298ba0c --- src/backends/reference/RefLayerSupport.cpp | 147 +++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) (limited to 'src/backends/reference/RefLayerSupport.cpp') diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 1b05c4e0f4..2603371927 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -1242,6 +1242,7 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, "Reference Lstm: input and outputStateOut types are mismatched"); supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported, "Reference Lstm: input and cellStateOut types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, "Reference Lstm: input and output types are mismatched"); // check layer parameters @@ -2288,4 +2289,150 @@ bool RefLayerSupport::IsTransposeSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( + const TensorInfo& input, + const TensorInfo& outputStateIn, + const TensorInfo& cellStateIn, + const TensorInfo& output, + const Optional& hiddenStateOutput, + const Optional& cellStateOutput, + const UnidirectionalSequenceLstmDescriptor& descriptor, + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const +{ + IgnoreUnused(descriptor); + IgnoreUnused(paramsInfo); + IgnoreUnused(outputStateIn); + IgnoreUnused(cellStateIn); + bool supported = true; + + if (hiddenStateOutput.has_value() || cellStateOutput.has_value()) + { + reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output " + "and cell state output are not supported at the moment."; + } + + std::array supportedTypes = + { + DataType::Float32 + }; + + std::array supportedWeightTypes = + { + DataType::Float32 + }; + + // check inputs and outputs + supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input is not a supported type."); + supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched"); + + supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and output types are mismatched"); + // check layer parameters + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToForgetWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToCellWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToCellWeights is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToOutputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToOutputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToForgetWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToForgetWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToCellWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToCellWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToOutputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types " + "are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and OutputGateBias types " + "are mismatched"); + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputToInputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetRecurrentToInputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and InputGateBias types " + "are mismatched"); + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellToInputWeights " + "is not a supported type."); + } + } + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToForgetWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellToForgetWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToOutputWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellToOutputWeights " + "is not a supported type."); + } + if (descriptor.m_ProjectionEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetProjectionWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: ProjectionWeights " + "is not a supported type."); + if (paramsInfo.m_ProjectionBias != nullptr) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetProjectionBias()), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: input and ProjectionBias types " + "are mismatched"); + } + } + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputLayerNormWeights " + "is not a supported type."); + } + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: ForgetLayerNormWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellLayerNormWeights " + "is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputLayerNormWeights(), supportedWeightTypes), + reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: OutputLayerNormWeights " + "is not a supported type."); + } + + return supported; +} + } // namespace armnn -- cgit v1.2.1