diff options
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 112 |
1 files changed, 34 insertions, 78 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 66661cb521..919c6db6ff 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -465,57 +465,15 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type, "hiddenStateOutputVal, cellStateOutputVal, output}"); } auto desc = *(PolymorphicDowncast<const UnidirectionalSequenceLstmDescriptor*>(&descriptor)); - - bool isHiddenStateOutputOptional = (infos[4] == TensorInfo()); - bool isCellStateOutput = (infos[5] == TensorInfo()); - if (isHiddenStateOutputOptional && isCellStateOutput) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - EmptyOptional(), - EmptyOptional(), - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else if (isHiddenStateOutputOptional) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - EmptyOptional(), - infos[5], - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else if (isCellStateOutput) - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - EmptyOptional(), - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } - else - { - return IsUnidirectionalSequenceLstmSupported(infos[0], - infos[1], - infos[2], - infos[3], - infos[4], - infos[5], - desc, - lstmParamsInfo.value(), - reasonIfUnsupported); - } + return IsUnidirectionalSequenceLstmSupported(infos[0], + infos[1], + infos[2], + infos[3], + infos[4], + infos[5], + desc, + lstmParamsInfo.value(), + reasonIfUnsupported); } case LayerType::Pooling3d: return IsPooling3dSupported(infos[0], @@ -2841,9 +2799,9 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( const TensorInfo& input, const TensorInfo& outputStateIn, const TensorInfo& cellStateIn, + const TensorInfo& outputStateOut, + const TensorInfo& cellStateOut, const TensorInfo& output, - const Optional<TensorInfo>& hiddenStateOutput, - const Optional<TensorInfo>& cellStateOutput, const UnidirectionalSequenceLstmDescriptor& descriptor, const LstmInputParamsInfo& paramsInfo, Optional<std::string&> reasonIfUnsupported) const @@ -2852,17 +2810,14 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( IgnoreUnused(paramsInfo); IgnoreUnused(outputStateIn); IgnoreUnused(cellStateIn); + IgnoreUnused(outputStateOut); + IgnoreUnused(cellStateOut); bool supported = true; - if (hiddenStateOutput.has_value() || cellStateOutput.has_value()) + std::array<DataType, 2> supportedTypes = { - reasonIfUnsupported.value() += "Reference UnidirectionalSequenceLstm: hidden state output " - "and cell state output are not supported at the moment."; - } - - std::array<DataType, 1> supportedTypes = - { - DataType::Float32 + DataType::Float32, + DataType::QAsymmS8 }; std::array<DataType, 2> supportedWeightTypes = @@ -2871,16 +2826,19 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( DataType::QAsymmS8 }; + std::array<DataType, 3> supportedBiasTypes = + { + DataType::Float32, + DataType::QAsymmS8, + DataType::Signed32 + }; + // check inputs and outputs supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, "Reference UnidirectionalSequenceLstm: input is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and outputStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and cellStateIn types are mismatched"); + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: output is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and output types are mismatched"); // check layer parameters supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToForgetWeights(), supportedWeightTypes), reasonIfUnsupported, @@ -2905,14 +2863,13 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( reasonIfUnsupported, "Reference UnidirectionalSequenceLstm: RecurrentToOutputWeights " "is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetForgetGateBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and ForgetGateBias types " - "are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetCellBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and CellBias types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetOutputGateBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and OutputGateBias types " - "are mismatched"); + + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetForgetGateBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: ForgetGateBias is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: CellBias is not a supported type."); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetOutputGateBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: OutputGateBias is not a supported type."); if (!descriptor.m_CifgEnabled) { supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputToInputWeights(), supportedWeightTypes), @@ -2923,9 +2880,8 @@ bool RefLayerSupport::IsUnidirectionalSequenceLstmSupported( reasonIfUnsupported, "Reference UnidirectionalSequenceLstm: RecurrentToInputWeights " "is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.GetInputGateBias()), reasonIfUnsupported, - "Reference UnidirectionalSequenceLstm: input and InputGateBias types " - "are mismatched"); + supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetInputGateBias(), supportedBiasTypes), reasonIfUnsupported, + "Reference UnidirectionalSequenceLstm: InputGateBias is not a supported type."); if (descriptor.m_PeepholeEnabled) { supported &= CheckSupportRule(TypeAnyOf(paramsInfo.GetCellToInputWeights(), supportedWeightTypes), |