diff options
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 123 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 24 |
2 files changed, 76 insertions, 71 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index ac7f310c1d..59c14c4490 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -924,51 +924,11 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported, - const TensorInfo* inputLayerNormWeights, - const TensorInfo* forgetLayerNormWeights, - const TensorInfo* cellLayerNormWeights, - const TensorInfo* outputLayerNormWeights) const + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) const { ignore_unused(descriptor); - ignore_unused(inputToForgetWeights); - ignore_unused(inputToCellWeights); - ignore_unused(inputToOutputWeights); - ignore_unused(recurrentToForgetWeights); - ignore_unused(recurrentToCellWeights); - ignore_unused(recurrentToOutputWeights); - ignore_unused(forgetGateBias); - ignore_unused(cellBias); - ignore_unused(outputGateBias); - ignore_unused(inputToInputWeights); - ignore_unused(recurrentToInputWeights); - ignore_unused(cellToInputWeights); - ignore_unused(inputGateBias); - ignore_unused(projectionWeights); - ignore_unused(projectionBias); - ignore_unused(cellToForgetWeights); - ignore_unused(cellToOutputWeights); - ignore_unused(inputLayerNormWeights); - ignore_unused(forgetLayerNormWeights); - ignore_unused(cellLayerNormWeights); - ignore_unused(outputLayerNormWeights); + ignore_unused(paramsInfo); bool supported = true; @@ -977,26 +937,91 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, DataType::QuantisedSymm16 }; + // check inputs and outputs supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, "Reference Lstm: input is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, "Reference Lstm: input and outputStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, "Reference Lstm: input and cellStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported, "Reference Lstm: input and scratchBuffer types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported, "Reference Lstm: input and outputStateOut types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported, "Reference Lstm: input and cellStateOut types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, "Reference Lstm: input and output types are mismatched"); + // check layer parameters + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToCellWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToOutputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToCellWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToOutputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported, + "Reference Lstm: input and ForgetGateBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported, + "Reference Lstm: input and CellBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported, + "Reference Lstm: input and OutputGateBias types are mismatched"); + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToInputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()), + reasonIfUnsupported, + "Reference Lstm: input and RecurrentToInputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported, + "Reference Lstm: input and InputGateBias types are mismatched"); + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()), + reasonIfUnsupported, + "Reference Lstm: input and CellToInputWeights types are mismatched"); + } + } + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and CellToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and CellToOutputWeights types are mismatched"); + } + if (descriptor.m_ProjectionEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported, + "Reference Lstm: input and mProjectionWeights types are mismatched"); + if (paramsInfo.m_ProjectionBias != nullptr) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported, + "Reference Lstm: input and ProjectionBias types are mismatched"); + } + } + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and InputLayerNormWeights types are mismatched"); + } + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and ForgetLayerNormWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and CellLayerNormWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and OutputLayerNormWeights types are mismatched"); + } return supported; } diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index ead4d1ce4a..c0bf18824e 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -138,28 +138,8 @@ public: const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported = EmptyOptional(), - const TensorInfo* inputLayerNormWeights = nullptr, - const TensorInfo* forgetLayerNormWeights = nullptr, - const TensorInfo* cellLayerNormWeights = nullptr, - const TensorInfo* outputLayerNormWeights = nullptr) const override; + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; bool IsMaximumSupported(const TensorInfo& input0, const TensorInfo& input1, |