diff options
author | Jan Eilers <jan.eilers@arm.com> | 2019-07-03 18:20:40 +0100 |
---|---|---|
committer | Nikhil Raj Arm <nikhil.raj@arm.com> | 2019-07-09 11:22:28 +0000 |
commit | d01a83c8de77c44a938a618918d17385da3baa88 (patch) | |
tree | fca6f5422adfbdcce059049b36d32e0168edcef4 /src/backends/reference/RefLayerSupport.cpp | |
parent | e6eaf661c5b84f4ca051daaf08281d9b8de3fcb9 (diff) | |
download | armnn-d01a83c8de77c44a938a618918d17385da3baa88.tar.gz |
IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported
!android-nn-driver:1461
Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 123 |
1 files changed, 74 insertions, 49 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index ac7f310c1d..59c14c4490 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -924,51 +924,11 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional<std::string&> reasonIfUnsupported, - const TensorInfo* inputLayerNormWeights, - const TensorInfo* forgetLayerNormWeights, - const TensorInfo* cellLayerNormWeights, - const TensorInfo* outputLayerNormWeights) const + const LstmInputParamsInfo& paramsInfo, + Optional<std::string&> reasonIfUnsupported) const { ignore_unused(descriptor); - ignore_unused(inputToForgetWeights); - ignore_unused(inputToCellWeights); - ignore_unused(inputToOutputWeights); - ignore_unused(recurrentToForgetWeights); - ignore_unused(recurrentToCellWeights); - ignore_unused(recurrentToOutputWeights); - ignore_unused(forgetGateBias); - ignore_unused(cellBias); - ignore_unused(outputGateBias); - ignore_unused(inputToInputWeights); - ignore_unused(recurrentToInputWeights); - ignore_unused(cellToInputWeights); - ignore_unused(inputGateBias); - ignore_unused(projectionWeights); - ignore_unused(projectionBias); - ignore_unused(cellToForgetWeights); - ignore_unused(cellToOutputWeights); - ignore_unused(inputLayerNormWeights); - ignore_unused(forgetLayerNormWeights); - ignore_unused(cellLayerNormWeights); - ignore_unused(outputLayerNormWeights); + ignore_unused(paramsInfo); bool supported = true; @@ -977,26 +937,91 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, DataType::QuantisedSymm16 }; + // check inputs and outputs supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, "Reference Lstm: input is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, "Reference Lstm: input and outputStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, "Reference Lstm: input and cellStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported, "Reference Lstm: input and scratchBuffer types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported, "Reference Lstm: input and outputStateOut types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported, "Reference Lstm: input and cellStateOut types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, "Reference Lstm: input and output types are mismatched"); + // check layer parameters + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToCellWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToOutputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToCellWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToOutputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported, + "Reference Lstm: input and ForgetGateBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported, + "Reference Lstm: input and CellBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported, + "Reference Lstm: input and OutputGateBias types are mismatched"); + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToInputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()), + reasonIfUnsupported, + "Reference Lstm: input and RecurrentToInputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported, + "Reference Lstm: input and InputGateBias types are mismatched"); + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()), + reasonIfUnsupported, + "Reference Lstm: input and CellToInputWeights types are mismatched"); + } + } + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and CellToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and CellToOutputWeights types are mismatched"); + } + if (descriptor.m_ProjectionEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported, + "Reference Lstm: input and mProjectionWeights types are mismatched"); + if (paramsInfo.m_ProjectionBias != nullptr) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported, + "Reference Lstm: input and ProjectionBias types are mismatched"); + } + } + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and InputLayerNormWeights types are mismatched"); + } + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and ForgetLayerNormWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and CellLayerNormWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and OutputLayerNormWeights types are mismatched"); + } return supported; } |