aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/RefLayerSupport.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-07-03 18:20:40 +0100
committerNikhil Raj Arm <nikhil.raj@arm.com>2019-07-09 11:22:28 +0000
commitd01a83c8de77c44a938a618918d17385da3baa88 (patch)
treefca6f5422adfbdcce059049b36d32e0168edcef4 /src/backends/reference/RefLayerSupport.cpp
parente6eaf661c5b84f4ca051daaf08281d9b8de3fcb9 (diff)
downloadarmnn-d01a83c8de77c44a938a618918d17385da3baa88.tar.gz
IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported
!android-nn-driver:1461 Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/reference/RefLayerSupport.cpp')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp123
1 files changed, 74 insertions, 49 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index ac7f310c1d..59c14c4490 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -924,51 +924,11 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
const TensorInfo& cellStateOut,
const TensorInfo& output,
const LstmDescriptor& descriptor,
- const TensorInfo& inputToForgetWeights,
- const TensorInfo& inputToCellWeights,
- const TensorInfo& inputToOutputWeights,
- const TensorInfo& recurrentToForgetWeights,
- const TensorInfo& recurrentToCellWeights,
- const TensorInfo& recurrentToOutputWeights,
- const TensorInfo& forgetGateBias,
- const TensorInfo& cellBias,
- const TensorInfo& outputGateBias,
- const TensorInfo* inputToInputWeights,
- const TensorInfo* recurrentToInputWeights,
- const TensorInfo* cellToInputWeights,
- const TensorInfo* inputGateBias,
- const TensorInfo* projectionWeights,
- const TensorInfo* projectionBias,
- const TensorInfo* cellToForgetWeights,
- const TensorInfo* cellToOutputWeights,
- Optional<std::string&> reasonIfUnsupported,
- const TensorInfo* inputLayerNormWeights,
- const TensorInfo* forgetLayerNormWeights,
- const TensorInfo* cellLayerNormWeights,
- const TensorInfo* outputLayerNormWeights) const
+ const LstmInputParamsInfo& paramsInfo,
+ Optional<std::string&> reasonIfUnsupported) const
{
ignore_unused(descriptor);
- ignore_unused(inputToForgetWeights);
- ignore_unused(inputToCellWeights);
- ignore_unused(inputToOutputWeights);
- ignore_unused(recurrentToForgetWeights);
- ignore_unused(recurrentToCellWeights);
- ignore_unused(recurrentToOutputWeights);
- ignore_unused(forgetGateBias);
- ignore_unused(cellBias);
- ignore_unused(outputGateBias);
- ignore_unused(inputToInputWeights);
- ignore_unused(recurrentToInputWeights);
- ignore_unused(cellToInputWeights);
- ignore_unused(inputGateBias);
- ignore_unused(projectionWeights);
- ignore_unused(projectionBias);
- ignore_unused(cellToForgetWeights);
- ignore_unused(cellToOutputWeights);
- ignore_unused(inputLayerNormWeights);
- ignore_unused(forgetLayerNormWeights);
- ignore_unused(cellLayerNormWeights);
- ignore_unused(outputLayerNormWeights);
+ ignore_unused(paramsInfo);
bool supported = true;
@@ -977,26 +937,91 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input,
DataType::QuantisedSymm16
};
+ // check inputs and outputs
supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
"Reference Lstm: input is not a supported type.");
-
supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported,
"Reference Lstm: input and outputStateIn types are mismatched");
-
supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported,
"Reference Lstm: input and cellStateIn types are mismatched");
-
supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported,
"Reference Lstm: input and scratchBuffer types are mismatched");
-
supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported,
"Reference Lstm: input and outputStateOut types are mismatched");
-
supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported,
"Reference Lstm: input and cellStateOut types are mismatched");
-
supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
"Reference Lstm: input and output types are mismatched");
+ // check layer parameters
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and InputToForgetWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and InputToCellWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and InputToOutputWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and RecurrentToForgetWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and RecurrentToCellWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and RecurrentToOutputWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported,
+ "Reference Lstm: input and ForgetGateBias types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported,
+ "Reference Lstm: input and CellBias types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported,
+ "Reference Lstm: input and OutputGateBias types are mismatched");
+ if (!descriptor.m_CifgEnabled)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and InputToInputWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and RecurrentToInputWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported,
+ "Reference Lstm: input and InputGateBias types are mismatched");
+ if (descriptor.m_PeepholeEnabled)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and CellToInputWeights types are mismatched");
+ }
+ }
+ if (descriptor.m_PeepholeEnabled)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and CellToForgetWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and CellToOutputWeights types are mismatched");
+ }
+ if (descriptor.m_ProjectionEnabled)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported,
+ "Reference Lstm: input and mProjectionWeights types are mismatched");
+ if (paramsInfo.m_ProjectionBias != nullptr)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported,
+ "Reference Lstm: input and ProjectionBias types are mismatched");
+ }
+ }
+ if (descriptor.m_LayerNormEnabled)
+ {
+ if (!descriptor.m_CifgEnabled)
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and InputLayerNormWeights types are mismatched");
+ }
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and ForgetLayerNormWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and CellLayerNormWeights types are mismatched");
+ supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()),
+ reasonIfUnsupported,
+ "Reference Lstm: input and OutputLayerNormWeights types are mismatched");
+ }
return supported;
}